From fb9b0cda058879e4c7adb9365d503682903ab7e6 Mon Sep 17 00:00:00 2001 From: Ben Vanik Date: Tue, 2 Mar 2021 13:33:12 -0800 Subject: [PATCH 01/17] Fixing leak in bytecode_dispatch_test. Would be really, really nice to have an ASAN bot. --- iree/vm/bytecode_dispatch_test.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/iree/vm/bytecode_dispatch_test.cc b/iree/vm/bytecode_dispatch_test.cc index bbb6fdf14d8f..b371bb16e2b0 100644 --- a/iree/vm/bytecode_dispatch_test.cc +++ b/iree/vm/bytecode_dispatch_test.cc @@ -131,6 +131,7 @@ TEST_P(VMBytecodeDispatchTest, Check) { } } else { if (expect_failure) { + iree_status_ignore(status); GTEST_SUCCEED(); } else { GTEST_FAIL() << "Function expected success but failed with error: " From 6f56791d55e874384349dab785765cc4b42e08ff Mon Sep 17 00:00:00 2001 From: Ben Vanik Date: Tue, 16 Feb 2021 16:41:05 -0800 Subject: [PATCH 02/17] Removing unused StreamableOpInterface methods. --- iree/compiler/Dialect/Flow/IR/FlowBase.td | 11 +---------- iree/compiler/Dialect/Flow/IR/FlowOps.td | 17 ----------------- .../Dialect/Flow/Transforms/FormStreams.cpp | 4 ++-- .../Flow/Transforms/HoistUnstreamableOps.cpp | 4 ++-- 4 files changed, 5 insertions(+), 31 deletions(-) diff --git a/iree/compiler/Dialect/Flow/IR/FlowBase.td b/iree/compiler/Dialect/Flow/IR/FlowBase.td index 50a9bc4a2230..72d044f6ac4b 100644 --- a/iree/compiler/Dialect/Flow/IR/FlowBase.td +++ b/iree/compiler/Dialect/Flow/IR/FlowBase.td @@ -90,8 +90,7 @@ def FLOW_StreamableOp : OpInterface<"StreamableOpInterface"> { The goal of the stream forming process is to move as many operations that can be used within a stream into one and only using non-streamed ops as a - last resort. Ops that are isStreamOnly may force the creation of single-op - command buffers and synchronous dispatches. + last resort. }]; let methods = [ @@ -99,14 +98,6 @@ def FLOW_StreamableOp : OpInterface<"StreamableOpInterface"> { [{Returns true if the op is transfer operation (as defined by the HAL).}], "bool", "isTransfer", (ins) >, - InterfaceMethod< - [{Returns true if the op *can* be used within a stream.}], - "bool", "isUsableInStream", (ins) - >, - InterfaceMethod< - [{Returns true if the op *must* be used within a stream.}], - "bool", "isStreamOnly", (ins) - >, ]; } diff --git a/iree/compiler/Dialect/Flow/IR/FlowOps.td b/iree/compiler/Dialect/Flow/IR/FlowOps.td index 947c7b1e37c7..c6f254835923 100644 --- a/iree/compiler/Dialect/Flow/IR/FlowOps.td +++ b/iree/compiler/Dialect/Flow/IR/FlowOps.td @@ -683,8 +683,6 @@ def FLOW_DispatchOp : FLOW_PureOp<"dispatch", [ // StreamableOpInterface: bool isTransfer() { return false; } - bool isUsableInStream() { return true; } - bool isStreamOnly() { return true; } }]; let assemblyFormat = [{ @@ -722,9 +720,6 @@ def FLOW_TensorReshapeOp : FLOW_PureOp<"tensor.reshape", [ let extraClassDeclaration = [{ // StreamableOpInterface: bool isTransfer() { return true; } - bool isUsableInStream() { return true; } - // TODO(benvanik): allow out of stream to act as a shape manipulation. - bool isStreamOnly() { return true; } }]; // TODO(benvanik): canonicalize away if resulting ops don't care. @@ -809,9 +804,6 @@ def FLOW_TensorSplatOp : FLOW_PureOp<"tensor.splat", [ let extraClassDeclaration = [{ // StreamableOpInterface: bool isTransfer() { return true; } - bool isUsableInStream() { return true; } - // TODO(benvanik): allow out of stream to act as a hal.buffer.fill. - bool isStreamOnly() { return true; } }]; // TODO(benvanik): canonicalize splat+slice to smaller splat. @@ -839,9 +831,6 @@ def FLOW_TensorCloneOp : FLOW_PureOp<"tensor.clone", [ let extraClassDeclaration = [{ // StreamableOpInterface: bool isTransfer() { return true; } - bool isUsableInStream() { return true; } - // TODO(benvanik): allow out of stream to act as a hal.buffer.copy. - bool isStreamOnly() { return true; } }]; // TODO(benvanik): canonicalize away entirely in most cases. @@ -877,9 +866,6 @@ def FLOW_TensorSliceOp : FLOW_PureOp<"tensor.slice", [ let extraClassDeclaration = [{ // StreamableOpInterface: bool isTransfer() { return true; } - bool isUsableInStream() { return true; } - // TODO(benvanik): allow out of stream to act as a hal.buffer.slice. - bool isStreamOnly() { return true; } }]; // TODO(benvanik): canonicalize multiple slices (traverse upward through ssa). @@ -915,9 +901,6 @@ def FLOW_TensorUpdateOp : FLOW_PureOp<"tensor.update", [ let extraClassDeclaration = [{ // StreamableOpInterface: bool isTransfer() { return true; } - bool isUsableInStream() { return true; } - // TODO(benvanik): allow out of stream to act as a hal.buffer.copy. - bool isStreamOnly() { return true; } }]; // TODO(benvanik): canonicalize contiguous updates/across slices. diff --git a/iree/compiler/Dialect/Flow/Transforms/FormStreams.cpp b/iree/compiler/Dialect/Flow/Transforms/FormStreams.cpp index 4a86c0d92c9d..1d5b8f50903b 100644 --- a/iree/compiler/Dialect/Flow/Transforms/FormStreams.cpp +++ b/iree/compiler/Dialect/Flow/Transforms/FormStreams.cpp @@ -43,8 +43,8 @@ namespace Flow { // Returns true if the given op can be used within a stream. static bool isStreamableOp(Operation *op) { - if (auto streamableOp = dyn_cast(op)) { - return streamableOp.isUsableInStream(); + if (isa(op)) { + return true; } if (llvm::isa(op)) { return true; diff --git a/iree/compiler/Dialect/Flow/Transforms/HoistUnstreamableOps.cpp b/iree/compiler/Dialect/Flow/Transforms/HoistUnstreamableOps.cpp index 62e9449b4c95..7e2aca509c5d 100644 --- a/iree/compiler/Dialect/Flow/Transforms/HoistUnstreamableOps.cpp +++ b/iree/compiler/Dialect/Flow/Transforms/HoistUnstreamableOps.cpp @@ -29,8 +29,8 @@ namespace IREE { namespace Flow { static bool isStreamableOp(Operation *op) { - if (auto streamableOp = dyn_cast(op)) { - return streamableOp.isUsableInStream(); + if (isa(op)) { + return true; } if (llvm::isa(op)) { return true; From 28ada606909c3cd72c9518561b50ca81c73ee370 Mon Sep 17 00:00:00 2001 From: Ben Vanik Date: Tue, 16 Feb 2021 18:25:24 -0800 Subject: [PATCH 03/17] Fixing name of flow tblgen interfaces files. --- iree/compiler/Dialect/Flow/IR/BUILD | 12 ++++++------ iree/compiler/Dialect/Flow/IR/CMakeLists.txt | 10 +++++----- iree/compiler/Dialect/Flow/IR/FlowDialect.cpp | 2 +- iree/compiler/Dialect/Flow/IR/FlowDialect.h | 2 +- 4 files changed, 13 insertions(+), 13 deletions(-) diff --git a/iree/compiler/Dialect/Flow/IR/BUILD b/iree/compiler/Dialect/Flow/IR/BUILD index 825e6b37e283..0fc7a364440e 100644 --- a/iree/compiler/Dialect/Flow/IR/BUILD +++ b/iree/compiler/Dialect/Flow/IR/BUILD @@ -31,8 +31,8 @@ cc_library( srcs = [ "FlowDialect.cpp", "FlowEnums.cpp.inc", + "FlowInterfaces.cpp.inc", "FlowOpFolders.cpp", - "FlowOpInterface.cpp.inc", "FlowOpUtils.cpp", "FlowOps.cpp", "FlowOps.cpp.inc", @@ -41,7 +41,7 @@ cc_library( hdrs = [ "FlowDialect.h", "FlowEnums.h.inc", - "FlowOpInterface.h.inc", + "FlowInterfaces.h.inc", "FlowOpUtils.h", "FlowOps.h", "FlowOps.h.inc", @@ -49,7 +49,7 @@ cc_library( ], deps = [ ":FlowEnumsGen", - ":FlowOpInterfaceGen", + ":FlowInterfacesGen", ":FlowOpsGen", "//iree/compiler/Dialect/IREE/IR", "//iree/compiler/Dialect/Shape/IR", @@ -84,10 +84,10 @@ gentbl( ) gentbl( - name = "FlowOpInterfaceGen", + name = "FlowInterfacesGen", tbl_outs = [ - ("-gen-op-interface-decls", "FlowOpInterface.h.inc"), - ("-gen-op-interface-defs", "FlowOpInterface.cpp.inc"), + ("-gen-op-interface-decls", "FlowInterfaces.h.inc"), + ("-gen-op-interface-defs", "FlowInterfaces.cpp.inc"), ], tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "FlowBase.td", diff --git a/iree/compiler/Dialect/Flow/IR/CMakeLists.txt b/iree/compiler/Dialect/Flow/IR/CMakeLists.txt index 1cff61a9f2d4..cf8d5ecea3bb 100644 --- a/iree/compiler/Dialect/Flow/IR/CMakeLists.txt +++ b/iree/compiler/Dialect/Flow/IR/CMakeLists.txt @@ -17,7 +17,7 @@ iree_cc_library( HDRS "FlowDialect.h" "FlowEnums.h.inc" - "FlowOpInterface.h.inc" + "FlowInterfaces.h.inc" "FlowOpUtils.h" "FlowOps.h" "FlowOps.h.inc" @@ -25,8 +25,8 @@ iree_cc_library( SRCS "FlowDialect.cpp" "FlowEnums.cpp.inc" + "FlowInterfaces.cpp.inc" "FlowOpFolders.cpp" - "FlowOpInterface.cpp.inc" "FlowOpUtils.cpp" "FlowOps.cpp" "FlowOps.cpp.inc" @@ -58,12 +58,12 @@ iree_tablegen_library( iree_tablegen_library( NAME - FlowOpInterfaceGen + FlowInterfacesGen TD_FILE "FlowBase.td" OUTS - -gen-op-interface-decls FlowOpInterface.h.inc - -gen-op-interface-defs FlowOpInterface.cpp.inc + -gen-op-interface-decls FlowInterfaces.h.inc + -gen-op-interface-defs FlowInterfaces.cpp.inc ) iree_tablegen_library( diff --git a/iree/compiler/Dialect/Flow/IR/FlowDialect.cpp b/iree/compiler/Dialect/Flow/IR/FlowDialect.cpp index 6923e5a573f8..598dddd24ce6 100644 --- a/iree/compiler/Dialect/Flow/IR/FlowDialect.cpp +++ b/iree/compiler/Dialect/Flow/IR/FlowDialect.cpp @@ -30,7 +30,7 @@ namespace iree_compiler { namespace IREE { namespace Flow { -#include "iree/compiler/Dialect/Flow/IR/FlowOpInterface.cpp.inc" +#include "iree/compiler/Dialect/Flow/IR/FlowInterfaces.cpp.inc" namespace { diff --git a/iree/compiler/Dialect/Flow/IR/FlowDialect.h b/iree/compiler/Dialect/Flow/IR/FlowDialect.h index 4e916c953ae7..9622041f413d 100644 --- a/iree/compiler/Dialect/Flow/IR/FlowDialect.h +++ b/iree/compiler/Dialect/Flow/IR/FlowDialect.h @@ -24,7 +24,7 @@ namespace iree_compiler { namespace IREE { namespace Flow { -#include "iree/compiler/Dialect/Flow/IR/FlowOpInterface.h.inc" +#include "iree/compiler/Dialect/Flow/IR/FlowInterfaces.h.inc" class FlowDialect : public Dialect { public: From 15d03a435eb30da1dc6a8acf0f30305c4cf63291 Mon Sep 17 00:00:00 2001 From: Ben Vanik Date: Tue, 16 Feb 2021 18:29:05 -0800 Subject: [PATCH 04/17] Adding ShapeCarryingOpInterface to allow ops to carry shape values. shapex.tie_shape now implements this interface and is no longer checked for specifically. This allows any op to carry ranked shape-compatible information. Its usage is kind of like ViewLikeOpInterface in that the op stores the dynamic dimensions and can return them as needed via the interface. --- iree/compiler/Dialect/Shape/IR/BUILD | 19 +++ iree/compiler/Dialect/Shape/IR/Builders.cpp | 130 ++++++++++++++---- iree/compiler/Dialect/Shape/IR/Builders.h | 20 ++- iree/compiler/Dialect/Shape/IR/CMakeLists.txt | 12 ++ iree/compiler/Dialect/Shape/IR/Folders.cpp | 68 +++++++-- iree/compiler/Dialect/Shape/IR/ShapeBase.td | 2 +- .../Dialect/Shape/IR/ShapeDialect.cpp | 2 + iree/compiler/Dialect/Shape/IR/ShapeDialect.h | 2 + .../Dialect/Shape/IR/ShapeInterfaces.td | 56 ++++++++ iree/compiler/Dialect/Shape/IR/ShapeOps.h | 1 + iree/compiler/Dialect/Shape/IR/ShapeOps.td | 10 +- iree/compiler/Dialect/Shape/IR/ShapeTypes.cpp | 4 + iree/compiler/Dialect/Shape/IR/ShapeTypes.h | 3 + .../Dialect/Shape/IR/test/canonicalize.mlir | 4 +- .../Conversion/HLOToVMLA/ConvertHLOToVMLA.cpp | 12 +- 15 files changed, 298 insertions(+), 47 deletions(-) create mode 100644 iree/compiler/Dialect/Shape/IR/ShapeInterfaces.td diff --git a/iree/compiler/Dialect/Shape/IR/BUILD b/iree/compiler/Dialect/Shape/IR/BUILD index 96fe609303f5..86c577cad944 100644 --- a/iree/compiler/Dialect/Shape/IR/BUILD +++ b/iree/compiler/Dialect/Shape/IR/BUILD @@ -35,6 +35,7 @@ cc_library( "Folders.cpp", "ShapeDialect.cpp", "ShapeInterface.cpp", + "ShapeInterfaces.cpp.inc", "ShapeOps.cpp", "ShapeOps.cpp.inc", "ShapeTypes.cpp", @@ -43,11 +44,13 @@ cc_library( "Builders.h", "ShapeDialect.h", "ShapeInterface.h", + "ShapeInterfaces.h.inc", "ShapeOps.h", "ShapeOps.h.inc", "ShapeTypes.h", ], deps = [ + ":ShapeInterfacesGen", ":ShapeOpsGen", "//iree/compiler/Dialect/IREE/IR", "//iree/compiler/Utils", @@ -64,6 +67,22 @@ cc_library( ], ) +gentbl( + name = "ShapeInterfacesGen", + tbl_outs = [ + ("-gen-op-interface-decls", "ShapeInterfaces.h.inc"), + ("-gen-op-interface-defs", "ShapeInterfaces.cpp.inc"), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "ShapeInterfaces.td", + td_srcs = [ + ":td_files", + "@llvm-project//mlir:OpBaseTdFiles", + "@llvm-project//mlir:StdOpsTdFiles", + "@llvm-project//mlir:include/mlir/Interfaces/InferTypeOpInterface.td", + ], +) + gentbl( name = "ShapeOpsGen", tbl_outs = [ diff --git a/iree/compiler/Dialect/Shape/IR/Builders.cpp b/iree/compiler/Dialect/Shape/IR/Builders.cpp index 78ae3e333717..14a34862f99a 100644 --- a/iree/compiler/Dialect/Shape/IR/Builders.cpp +++ b/iree/compiler/Dialect/Shape/IR/Builders.cpp @@ -14,6 +14,7 @@ #include "iree/compiler/Dialect/Shape/IR/Builders.h" +#include "iree/compiler/Dialect/Shape/IR/ShapeDialect.h" #include "iree/compiler/Dialect/Shape/IR/ShapeOps.h" #include "iree/compiler/Dialect/Shape/IR/ShapeTypes.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" @@ -23,27 +24,68 @@ namespace mlir { namespace iree_compiler { namespace Shape { -namespace { +static Value getRankedShapeFromOpResult(Operation *op, Value resultValue, + OpBuilder &builder) { + if (!op) return nullptr; + if (auto carryingOp = dyn_cast(op)) { + return carryingOp.buildResultValueRankedShape(resultValue, builder); + } else { + return nullptr; + } +} -Value getRankedShapeFromOp(Operation *op) { - auto tieOp = llvm::dyn_cast_or_null(op); - if (!tieOp) return nullptr; - auto shape = tieOp.shape(); - if (!shape.getType().isa()) return nullptr; - return shape; +static Value getRankedShapeFromOpOperand(Operation *op, unsigned idx, + OpBuilder &builder) { + auto carryingOp = dyn_cast_or_null(op); + if (!carryingOp) { + auto value = op->getOperand(idx); + auto definingOp = value.getDefiningOp(); + if (!definingOp) return nullptr; + return getRankedShapeFromOpResult(definingOp, value, builder); + } + return carryingOp.buildOperandRankedShape(idx, builder); } -Value findRankedShapeFromUse(Value value) { - Value rs = getRankedShapeFromOp(value.getDefiningOp()); +static Value findRankedShapeFromUse(Value value, OpBuilder &builder) { + Value rs = getRankedShapeFromOpResult(value.getDefiningOp(), value, builder); if (rs) return rs; for (auto &use : value.getUses()) { - rs = getRankedShapeFromOp(use.getOwner()); + rs = getRankedShapeFromOpOperand(use.getOwner(), use.getOperandNumber(), + builder); if (rs) return rs; } return nullptr; } -} // namespace +Value buildRankedShapeForValue(Location loc, Value shapedValue, + ValueRange dynamicDims, OpBuilder &builder) { + auto shapedType = shapedValue.getType().dyn_cast(); + assert(shapedType && "only valid to call on shaped types"); + return builder.createOrFold( + loc, Shape::RankedShapeType::get(shapedType), dynamicDims); +} + +// Slices out a range of |dynamicDims| corresponding to the value at |index|. +static ValueRange sliceDynamicDims(unsigned index, ValueRange values, + ValueRange dynamicDims) { + auto valueType = values[index].getType().dyn_cast(); + assert(valueType && "must be a shaped type to get dims"); + unsigned dimsIndex = 0; + for (unsigned i = 0; i < index; ++i) { + if (auto shapedType = values[i].getType().dyn_cast()) { + dimsIndex += shapedType.getNumDynamicDims(); + } + } + return dynamicDims.slice(dimsIndex, valueType.getNumDynamicDims()); +} + +Value buildRankedShapeForValueInList(Location loc, unsigned index, + ValueRange flatValues, + ValueRange flatDynamicDims, + OpBuilder &builder) { + auto dynamicDims = sliceDynamicDims(index, flatValues, flatDynamicDims); + return buildRankedShapeForValue(loc, flatValues[index], dynamicDims, builder); +} Value buildCastInputsToResultShape(Location loc, RankedShapeType resultShapeType, @@ -121,9 +163,10 @@ Value buildDegenerateBroadcastRankedShape( } } -LogicalResult getRankedDimsFromRankedShape( - Location loc, Value rsValue, bool createIntermediateOps, - SmallVectorImpl &outDims, ConversionPatternRewriter &rewriter) { +LogicalResult getRankedDimsFromRankedShape(Location loc, Value rsValue, + bool createIntermediateOps, + SmallVectorImpl &outDims, + OpBuilder &builder) { Operation *op = rsValue.getDefiningOp(); if (op && (llvm::isa(op) || llvm::isa(op))) { @@ -134,28 +177,26 @@ LogicalResult getRankedDimsFromRankedShape( if (dynamicDimIndex >= op->getNumOperands()) { return emitError(loc, "mismatched dynamic dimensions"); } - Value remappedValue = - rewriter.getRemappedValue(op->getOperand(dynamicDimIndex++)); - if (!remappedValue) { + Value dimValue = op->getOperand(dynamicDimIndex++); + if (!dimValue) { return emitError( loc, "unable to find remapped value for ranked dim value"); } - outDims.push_back(remappedValue); + outDims.push_back(dimValue); } else { outDims.push_back( - rewriter.create(loc, rsType.getStaticDim(i))); + builder.create(loc, rsType.getStaticDim(i))); } } - return success(); } else if (createIntermediateOps) { - auto dimsOp = rewriter.create(loc, rsValue); + auto dimsOp = builder.create(loc, rsValue); outDims.resize(dimsOp.result().size()); std::copy(dimsOp.result().begin(), dimsOp.result().end(), outDims.begin()); - return success(); } else { return emitError(loc, "could not resolve ranked dimensions from metadata ops"); } + return success(); } Value buildOrFindRankedShapeForValue(Location loc, Value value, Type dimType, @@ -174,7 +215,7 @@ Value buildOrFindRankedShapeForValue(Location loc, Value value, Type dimType, // Dynamic - walk the uses to find a tie_shape op (either this op or an // immediate use). - Value rs = findRankedShapeFromUse(value); + Value rs = findRankedShapeFromUse(value, builder); if (!rs) { builder.getContext()->getDiagEngine().emit(loc, DiagnosticSeverity::Error) << "dynamically shaped value is missing a shape association via " @@ -192,6 +233,49 @@ Value buildOrFindRankedShapeForValue(Location loc, Value value, Type dimType, return rs; } +SmallVector buildOrFindDynamicDimsForValue(Location loc, Value value, + OpBuilder &builder) { + auto valueSt = value.getType().dyn_cast(); + if (!valueSt) { + builder.getContext()->getDiagEngine().emit(loc, DiagnosticSeverity::Error) + << "cannot construct shape for non shaped value: " << value.getType(); + return {}; + } + + // Bail if all dimensions are static. + if (valueSt.hasStaticShape()) { + return {}; + } + + // Dynamic - walk the uses to find a tie_shape op (either this op or an + // immediate use). + SmallVector result; + Value rs = findRankedShapeFromUse(value, builder); + if (rs) { + auto rsType = rs.getType().dyn_cast(); + if (!rsType) { + builder.getContext()->getDiagEngine().emit(loc, DiagnosticSeverity::Error) + << "dynamically shaped value is not ranked (which is not yet " + << "supported)"; + return {}; + } + for (unsigned i = 0; i < rsType.getRank(); ++i) { + if (rsType.isDimDynamic(i)) { + result.push_back(builder.createOrFold(loc, rs, i)); + } + } + } else { + // No tie information - insert std.dim ops that may later be used and + // hopefully converted to ranked shape types. + for (unsigned i = 0; i < valueSt.getRank(); ++i) { + if (valueSt.isDynamicDim(i)) { + result.push_back(builder.createOrFold(loc, value, i)); + } + } + } + return result; +} + } // namespace Shape } // namespace iree_compiler } // namespace mlir diff --git a/iree/compiler/Dialect/Shape/IR/Builders.h b/iree/compiler/Dialect/Shape/IR/Builders.h index 825c69ac0282..9951272a398a 100644 --- a/iree/compiler/Dialect/Shape/IR/Builders.h +++ b/iree/compiler/Dialect/Shape/IR/Builders.h @@ -25,6 +25,18 @@ namespace mlir { namespace iree_compiler { namespace Shape { +// Builds a ranked_shape for the given |shapedValue| with zero or more dynamic +// dims with the values taken from |dynamicDims|. +Value buildRankedShapeForValue(Location loc, Value shapedValue, + ValueRange dynamicDims, OpBuilder &builder); + +// As with buildRankedShapeForValue but by selecting out the appropriate dims +// from a flattened set of values and dynamic dims. +Value buildRankedShapeForValueInList(Location loc, unsigned index, + ValueRange flatValues, + ValueRange flatDynamicDims, + OpBuilder &builder); + // Given an arbitrary list of inputs, builds IR to obtain their shapes and // cast them to a given !shapex.ranked_shape. Statically verifiable invariants // will be checked within this call and runtime code will be emitted to verify @@ -60,6 +72,12 @@ Value buildDegenerateBroadcastRankedShape( Value buildOrFindRankedShapeForValue(Location loc, Value value, Type dimType, OpBuilder &builder); +// Returns dimension values for each dynamic dimension of the given |value|. +// |value| must be a ShapedType and may optionally have a ranked_shape tied. +// The returned value range will be empty if the shape is fully static. +SmallVector buildOrFindDynamicDimsForValue(Location loc, Value value, + OpBuilder &builder); + // Given a RankedShapeType'd value |rsValue|, populate values for all // dimensions. If |createIntermediateOps|, then if the dims cannot be resolved // by walking the IR, then a RankedDimsOp is created. If false, then the dims @@ -68,7 +86,7 @@ Value buildOrFindRankedShapeForValue(Location loc, Value value, Type dimType, LogicalResult getRankedDimsFromRankedShape(Location loc, Value rsValue, bool createIntermediateOps, SmallVectorImpl &outDims, - ConversionPatternRewriter &rewriter); + OpBuilder &builder); } // namespace Shape } // namespace iree_compiler diff --git a/iree/compiler/Dialect/Shape/IR/CMakeLists.txt b/iree/compiler/Dialect/Shape/IR/CMakeLists.txt index 84ff59ddaff2..ca7179113282 100644 --- a/iree/compiler/Dialect/Shape/IR/CMakeLists.txt +++ b/iree/compiler/Dialect/Shape/IR/CMakeLists.txt @@ -18,6 +18,7 @@ iree_cc_library( "Builders.h" "ShapeDialect.h" "ShapeInterface.h" + "ShapeInterfaces.h.inc" "ShapeOps.h" "ShapeOps.h.inc" "ShapeTypes.h" @@ -26,6 +27,7 @@ iree_cc_library( "Folders.cpp" "ShapeDialect.cpp" "ShapeInterface.cpp" + "ShapeInterfaces.cpp.inc" "ShapeOps.cpp" "ShapeOps.cpp.inc" "ShapeTypes.cpp" @@ -45,6 +47,16 @@ iree_cc_library( PUBLIC ) +iree_tablegen_library( + NAME + ShapeInterfacesGen + TD_FILE + "ShapeInterfaces.td" + OUTS + -gen-op-interface-decls ShapeInterfaces.h.inc + -gen-op-interface-defs ShapeInterfaces.cpp.inc +) + iree_tablegen_library( NAME ShapeOpsGen diff --git a/iree/compiler/Dialect/Shape/IR/Folders.cpp b/iree/compiler/Dialect/Shape/IR/Folders.cpp index fc157937292b..a6349ee9b854 100644 --- a/iree/compiler/Dialect/Shape/IR/Folders.cpp +++ b/iree/compiler/Dialect/Shape/IR/Folders.cpp @@ -52,17 +52,17 @@ LogicalResult safeCastCompatibleShapePattern( return failure(); } -LogicalResult elideTiedGetRankedShapePattern(GetRankedShapeOp op, - GetRankedShapeOp::Adaptor operands, - PatternRewriter &rewriter) { - // If the immediate predecessor is a TieShapeOp, then this op can be - // erased in favor of the input to the tie op. - auto tieOp = dyn_cast_or_null(operands.operand().getDefiningOp()); - if (!tieOp) { - return rewriter.notifyMatchFailure(op, "no associated tie_shape op"); +LogicalResult elideShapeCarryingGetRankedShapePattern( + GetRankedShapeOp op, GetRankedShapeOp::Adaptor operands, + PatternRewriter &rewriter) { + auto carryingOp = dyn_cast_or_null( + operands.operand().getDefiningOp()); + if (!carryingOp) { + return rewriter.notifyMatchFailure(op, + "no associated dynamic-shape aware op"); } - - rewriter.replaceOp(op, tieOp.shape()); + rewriter.replaceOp( + op, carryingOp.buildResultValueRankedShape(operands.operand(), rewriter)); return success(); } @@ -225,6 +225,42 @@ LogicalResult elideDuplicateTieShapePattern(TieShapeOp op, return success(); } +// Removes tie_shape ops when the operand is produced by a shape-aware op. +LogicalResult elideShapeCarryingOperandTieShapePattern( + TieShapeOp op, TieShapeOp::Adaptor operands, PatternRewriter &rewriter) { + auto definingOp = operands.operand().getDefiningOp(); + if (!definingOp) return failure(); + if (isa(definingOp)) { + return failure(); // ignore tie-shape handled above + } else if (isa(definingOp)) { + rewriter.replaceOp(op, operands.operand()); + return success(); + } else { + return failure(); + } +} + +// Reroutes uses of tie_shape ops by ops that are shape-aware or dim ops. +LogicalResult elideTieShapeUsagePattern(TieShapeOp op, + TieShapeOp::Adaptor operands, + PatternRewriter &rewriter) { + bool didAnything = false; + for (auto &use : llvm::make_early_inc_range(op.result().getUses())) { + if (auto carryingOp = dyn_cast(use.getOwner())) { + carryingOp->setOperand(use.getOperandNumber(), operands.operand()); + didAnything = true; + } else if (auto dimOp = dyn_cast(use.getOwner())) { + auto index = dimOp.getConstantIndex(); + if (index.hasValue()) { + rewriter.replaceOpWithNewOp(dimOp, op.shape(), + index.getValue()); + didAnything = true; + } + } + } + return didAnything ? success() : failure(); +} + //===----------------------------------------------------------------------===// // shapex.tie_shape //===----------------------------------------------------------------------===// @@ -232,6 +268,9 @@ LogicalResult elideDuplicateTieShapePattern(TieShapeOp op, void TieShapeOp::getCanonicalizationPatterns(OwningRewritePatternList &patterns, MLIRContext *context) { insertGreedyPattern(patterns, context, elideDuplicateTieShapePattern); + insertGreedyPattern(patterns, context, + elideShapeCarryingOperandTieShapePattern); + insertGreedyPattern(patterns, context, elideTieShapeUsagePattern); } //===----------------------------------------------------------------------===// @@ -249,7 +288,8 @@ void CastCompatibleShapeOp::getCanonicalizationPatterns( void GetRankedShapeOp::getCanonicalizationPatterns( OwningRewritePatternList &patterns, MLIRContext *context) { - insertGreedyPattern(patterns, context, elideTiedGetRankedShapePattern); + insertGreedyPattern(patterns, context, + elideShapeCarryingGetRankedShapePattern); insertGreedyPattern(patterns, context, elideDuplicateGetRankedShapePattern); insertGreedyPattern(patterns, context, elideStaticGetRankedShapePattern); } @@ -358,7 +398,11 @@ void populateFoldConversionPatterns(MLIRContext *context, insertConversionPattern(patterns, context, elideDuplicateGetRankedShapePattern); insertConversionPattern(patterns, context, elideDuplicateTieShapePattern); - insertConversionPattern(patterns, context, elideTiedGetRankedShapePattern); + insertConversionPattern(patterns, context, + elideShapeCarryingOperandTieShapePattern); + insertConversionPattern(patterns, context, elideTieShapeUsagePattern); + insertConversionPattern(patterns, context, + elideShapeCarryingGetRankedShapePattern); insertConversionPattern(patterns, context, expandRankedShapeDimsPattern); insertConversionPattern(patterns, context, identityMakeRankedShapePattern); insertConversionPattern(patterns, context, elideStaticGetRankedShapePattern); diff --git a/iree/compiler/Dialect/Shape/IR/ShapeBase.td b/iree/compiler/Dialect/Shape/IR/ShapeBase.td index 180c7aca47fb..d577ad3aff76 100644 --- a/iree/compiler/Dialect/Shape/IR/ShapeBase.td +++ b/iree/compiler/Dialect/Shape/IR/ShapeBase.td @@ -18,7 +18,7 @@ include "mlir/IR/OpBase.td" //===----------------------------------------------------------------------===// -// Shape dialect. +// Shape dialect //===----------------------------------------------------------------------===// // TODO(b/143787186): rename when old dialects are removed. diff --git a/iree/compiler/Dialect/Shape/IR/ShapeDialect.cpp b/iree/compiler/Dialect/Shape/IR/ShapeDialect.cpp index 7586c5fb3d03..71511ec5b369 100644 --- a/iree/compiler/Dialect/Shape/IR/ShapeDialect.cpp +++ b/iree/compiler/Dialect/Shape/IR/ShapeDialect.cpp @@ -32,6 +32,8 @@ namespace mlir { namespace iree_compiler { +#include "iree/compiler/Dialect/Shape/IR/ShapeInterfaces.cpp.inc" + // Used to control inlining behavior. struct ShapeInlinerInterface : public DialectInlinerInterface { using DialectInlinerInterface::DialectInlinerInterface; diff --git a/iree/compiler/Dialect/Shape/IR/ShapeDialect.h b/iree/compiler/Dialect/Shape/IR/ShapeDialect.h index 6465384204c2..f865f84ce6ee 100644 --- a/iree/compiler/Dialect/Shape/IR/ShapeDialect.h +++ b/iree/compiler/Dialect/Shape/IR/ShapeDialect.h @@ -21,6 +21,8 @@ namespace mlir { namespace iree_compiler { +#include "iree/compiler/Dialect/Shape/IR/ShapeInterfaces.h.inc" + class ShapeDialect : public Dialect { public: explicit ShapeDialect(MLIRContext* context); diff --git a/iree/compiler/Dialect/Shape/IR/ShapeInterfaces.td b/iree/compiler/Dialect/Shape/IR/ShapeInterfaces.td new file mode 100644 index 000000000000..d9df80db260d --- /dev/null +++ b/iree/compiler/Dialect/Shape/IR/ShapeInterfaces.td @@ -0,0 +1,56 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_DIALECT_SHAPE_INTERFACES +#define IREE_DIALECT_SHAPE_INTERFACES + +include "mlir/IR/OpBase.td" + +//===----------------------------------------------------------------------===// +// Op interfaces +//===----------------------------------------------------------------------===// + +def Shape_ShapeCarryingOpInterface : OpInterface<"ShapeCarryingInterface"> { + let description = [{ + Interface for ops that interact with dynamically shaped inputs and outputs. + Such ops are able to materialize RankedShapes on demand for any operand or + result that derives from ShapedType. + }]; + + let methods = [ + StaticInterfaceMethod< + [{Returns a RankedShape for the given shaped result value.}], + "Value", "buildResultValueRankedShape", + (ins "Value":$result, "OpBuilder &":$builder), [{ + auto carryingOp = dyn_cast(result.getDefiningOp()); + for (unsigned idx = 0; idx < carryingOp->getNumResults(); ++idx) { + if (carryingOp->getResult(idx) == result) { + return carryingOp.buildResultRankedShape(idx, builder); + } + } + llvm_unreachable("called on result not from a shape-aware-op"); + }] + >, + InterfaceMethod< + [{Returns a RankedShape for the given shaped operand index.}], + "Value", "buildOperandRankedShape", (ins "unsigned":$idx, "OpBuilder &":$builder) + >, + InterfaceMethod< + [{Returns a RankedShape for the given shaped result index.}], + "Value", "buildResultRankedShape", (ins "unsigned":$idx, "OpBuilder &":$builder) + >, + ]; +} + +#endif // IREE_DIALECT_SHAPE_INTERFACES diff --git a/iree/compiler/Dialect/Shape/IR/ShapeOps.h b/iree/compiler/Dialect/Shape/IR/ShapeOps.h index ffd50296186e..ee4c30bf02a7 100644 --- a/iree/compiler/Dialect/Shape/IR/ShapeOps.h +++ b/iree/compiler/Dialect/Shape/IR/ShapeOps.h @@ -15,6 +15,7 @@ #ifndef IREE_COMPILER_DIALECT_SHAPE_IR_SHAPEOPS_H_ #define IREE_COMPILER_DIALECT_SHAPE_IR_SHAPEOPS_H_ +#include "iree/compiler/Dialect/Shape/IR/ShapeDialect.h" #include "iree/compiler/Dialect/Shape/IR/ShapeTypes.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/BuiltinTypes.h" diff --git a/iree/compiler/Dialect/Shape/IR/ShapeOps.td b/iree/compiler/Dialect/Shape/IR/ShapeOps.td index 617e526fd5e7..164f66744f81 100644 --- a/iree/compiler/Dialect/Shape/IR/ShapeOps.td +++ b/iree/compiler/Dialect/Shape/IR/ShapeOps.td @@ -16,6 +16,7 @@ #define IREE_DIALECT_SHAPE_OPS include "iree/compiler/Dialect/Shape/IR/ShapeBase.td" +include "iree/compiler/Dialect/Shape/IR/ShapeInterfaces.td" include "mlir/Interfaces/InferTypeOpInterface.td" include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/Interfaces/ViewLikeInterface.td" @@ -39,8 +40,9 @@ class Shape_PureOp traits = []> : //===----------------------------------------------------------------------===// def Shape_TieShapeOp : Shape_PureOp<"tie_shape", [ + Shape_ShapeCarryingOpInterface, AllTypesMatch<["operand", "result"]>, - DeclareOpInterfaceMethods + DeclareOpInterfaceMethods, ]> { let summary = "Ties a tensor and a shape together."; let description = [{ @@ -57,6 +59,12 @@ def Shape_TieShapeOp : Shape_PureOp<"tie_shape", [ let assemblyFormat = "operands attr-dict `:` type($operand) `,` type($shape)"; + let extraClassDeclaration = [{ + // ShapeCarryingInterface: + Value buildOperandRankedShape(unsigned, OpBuilder) { return shape(); } + Value buildResultRankedShape(unsigned, OpBuilder) { return shape(); } + }]; + let verifier = [{ return verify$cppClass(*this); }]; let hasCanonicalizer = 1; diff --git a/iree/compiler/Dialect/Shape/IR/ShapeTypes.cpp b/iree/compiler/Dialect/Shape/IR/ShapeTypes.cpp index 2dbec24a5024..ce549223c627 100644 --- a/iree/compiler/Dialect/Shape/IR/ShapeTypes.cpp +++ b/iree/compiler/Dialect/Shape/IR/ShapeTypes.cpp @@ -79,6 +79,10 @@ RankedShapeType RankedShapeType::getChecked( return Base::getChecked(emitError, context, dims); } +RankedShapeType RankedShapeType::get(ShapedType shapedType) { + return Base::get(shapedType.getContext(), shapedType.getShape()); +} + LogicalResult RankedShapeType::verify( function_ref emitError, ArrayRef dims) { for (auto dim : dims) { diff --git a/iree/compiler/Dialect/Shape/IR/ShapeTypes.h b/iree/compiler/Dialect/Shape/IR/ShapeTypes.h index b48aafb6da01..e8d43309b1c6 100644 --- a/iree/compiler/Dialect/Shape/IR/ShapeTypes.h +++ b/iree/compiler/Dialect/Shape/IR/ShapeTypes.h @@ -47,6 +47,9 @@ class RankedShapeType : public Type::TypeBase emitError, MLIRContext *context, ArrayRef dims); + // Derives a RankedShapeType from a ShapedType. + static RankedShapeType get(ShapedType shapedType); + // Verifies construction invariants and issues errors/warnings. static LogicalResult verify(function_ref emitError, ArrayRef dims); diff --git a/iree/compiler/Dialect/Shape/IR/test/canonicalize.mlir b/iree/compiler/Dialect/Shape/IR/test/canonicalize.mlir index dfef6a2c6a72..73070967dc46 100644 --- a/iree/compiler/Dialect/Shape/IR/test/canonicalize.mlir +++ b/iree/compiler/Dialect/Shape/IR/test/canonicalize.mlir @@ -123,8 +123,8 @@ func @elideDuplicateTieShapePattern_match(%arg0 : tensor, %arg1 : !shapex func @elideDuplicateTieShapePattern_different_shapes(%arg0 : tensor, %arg1 : !shapex.ranked_shape<[?]>, %arg2 : !shapex.ranked_shape<[?]>) -> (tensor) { %0 = shapex.tie_shape %arg0, %arg1 : tensor, !shapex.ranked_shape<[?]> %1 = shapex.tie_shape %0, %arg2 : tensor, !shapex.ranked_shape<[?]> - // CHECK: %[[T:.+]] = shapex.tie_shape %[[ARGT]], %[[ARGRS1]] - // CHECK: shapex.tie_shape %[[T]], %[[ARGRS2]] + // CHECK: %[[T:.+]] = shapex.tie_shape %[[ARGT]], %[[ARGRS2]] + // CHECK: return %[[T]] return %1 : tensor } diff --git a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertHLOToVMLA.cpp b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertHLOToVMLA.cpp index 6c2f44266bb4..c32b9f6ca870 100644 --- a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertHLOToVMLA.cpp +++ b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertHLOToVMLA.cpp @@ -874,11 +874,6 @@ struct ConvertOpConversion : public OpConversionPattern { void populateHLOToVMLAPatterns(MLIRContext *context, OwningRewritePatternList &patterns, TypeConverter &typeConverter) { - // We rely on some additional HLO->std patterns and assume they - // have been run already. In case they haven't we provide them here (useful - // for standalone conversion testing). - mhlo::PopulateMhloToStdPatterns(&patterns, context); - // mhlo.convolution. populateHLOConvToVMLAPatterns(context, patterns, typeConverter); @@ -995,8 +990,11 @@ void populateHLOToVMLAPatterns(MLIRContext *context, // runtime. patterns.insert(context); - // TODO(benvanik): add missing ops: - // - ConvOp + // We rely on some additional HLO->std patterns and assume they + // have been run already. In case they haven't we provide them here (useful + // for standalone conversion testing). We run them last so that other patterns + // have a chance to handle the HLO conversions first. + mhlo::PopulateMhloToStdPatterns(&patterns, context); } } // namespace iree_compiler From baf8a2dae983683ca9b163d5a2c21578db9d6dca Mon Sep 17 00:00:00 2001 From: "Ahmed S. Taei" Date: Wed, 10 Mar 2021 20:54:58 -0800 Subject: [PATCH 05/17] Enable mhlo.while e2e/xla_ops test for linalg-on-tensors path (#5054) --- iree/test/e2e/xla_ops/BUILD | 4 ++-- iree/test/e2e/xla_ops/CMakeLists.txt | 2 ++ 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/iree/test/e2e/xla_ops/BUILD b/iree/test/e2e/xla_ops/BUILD index 4967da1ef81b..bd6eff8b6d2a 100644 --- a/iree/test/e2e/xla_ops/BUILD +++ b/iree/test/e2e/xla_ops/BUILD @@ -185,7 +185,7 @@ iree_check_single_backend_test_suite( # https://github.com/google/iree/issues/4079 # "torch_index_select.mlir", "transpose.mlir", - # "while.mlir", + "while.mlir", ], compiler_flags = [ "-iree-flow-dispatch-linalg-on-tensors", @@ -246,7 +246,7 @@ iree_check_single_backend_test_suite( # https://github.com/google/iree/issues/4079 # "torch_index_select.mlir", "transpose.mlir", - # "while.mlir", + "while.mlir", ], compiler_flags = [ "-iree-flow-dispatch-linalg-on-tensors", diff --git a/iree/test/e2e/xla_ops/CMakeLists.txt b/iree/test/e2e/xla_ops/CMakeLists.txt index 90ed6bd88079..9a413975afb2 100644 --- a/iree/test/e2e/xla_ops/CMakeLists.txt +++ b/iree/test/e2e/xla_ops/CMakeLists.txt @@ -166,6 +166,7 @@ iree_check_single_backend_test_suite( "subtract.mlir" "tanh.mlir" "transpose.mlir" + "while.mlir" TARGET_BACKEND "dylib-llvm-aot" DRIVER @@ -213,6 +214,7 @@ iree_check_single_backend_test_suite( "subtract.mlir" "tanh.mlir" "transpose.mlir" + "while.mlir" TARGET_BACKEND "vulkan-spirv" DRIVER From 0ddc9e975f3b9d153ee19515938a6fa379bb6025 Mon Sep 17 00:00:00 2001 From: Ben Vanik Date: Tue, 16 Feb 2021 18:40:27 -0800 Subject: [PATCH 06/17] Making flow ops shape-aware. This removes most of the tie_shape uses outside of dispatch regions. It required reworking ClosureOpDce so that the shape-aware ops can still be optimized. Future changes will start edging us towards handling shapes with this interface such that the ties are only required when interoping with code that is not shape-aware. --- .../Conversion/HLOToFlow/ConvertHLOToFlow.cpp | 2 +- iree/compiler/Dialect/Flow/IR/BUILD | 2 +- iree/compiler/Dialect/Flow/IR/CMakeLists.txt | 2 +- iree/compiler/Dialect/Flow/IR/FlowBase.td | 23 +- .../Dialect/Flow/IR/FlowInterfaces.td | 116 ++++ .../Dialect/Flow/IR/FlowOpFolders.cpp | 183 +++-- iree/compiler/Dialect/Flow/IR/FlowOpUtils.cpp | 334 ++++++--- iree/compiler/Dialect/Flow/IR/FlowOpUtils.h | 83 ++- iree/compiler/Dialect/Flow/IR/FlowOps.cpp | 639 +++++++++++++----- iree/compiler/Dialect/Flow/IR/FlowOps.h | 1 + iree/compiler/Dialect/Flow/IR/FlowOps.td | 211 ++++-- .../Dialect/Flow/IR/test/dispatch_ops.mlir | 6 +- .../Flow/IR/test/dispatch_region_folding.mlir | 2 +- .../Flow/IR/test/dispatch_regions.mlir | 2 +- .../Flow/IR/test/dispatch_workgroups.mlir | 15 +- .../IR/test/dispatch_workgroups_folding.mlir | 18 +- .../Dialect/Flow/IR/test/stream_folding.mlir | 71 +- .../Dialect/Flow/IR/test/stream_ops.mlir | 6 +- .../Dialect/Flow/IR/test/tensor_folding.mlir | 53 +- .../Dialect/Flow/IR/test/tensor_ops.mlir | 60 ++ iree/compiler/Dialect/Flow/IR/test/types.mlir | 20 +- .../Flow/Transforms/CreateBenchmarkFuncs.cpp | 2 +- .../Transforms/DispatchLinalgOnTensors.cpp | 59 +- .../Dialect/Flow/Transforms/FormStreams.cpp | 66 +- .../Transforms/IdentifyDispatchRegions2.cpp | 8 +- .../Transforms/OutlineDispatchRegions.cpp | 19 +- .../Transforms/OutlineDispatchRegions2.cpp | 44 +- .../Dialect/Flow/Transforms/Passes.cpp | 12 +- .../RematerializeDispatchConstants.cpp | 31 +- .../test/create_benchmark_funcs.mlir | 10 +- .../test/deduplicate_executables.mlir | 30 +- .../test/dispatch_linalg_on_tensors.mlir | 25 +- .../dispatch_linalg_on_tensors_dynamic.mlir | 64 +- .../fold_compatible_dispatch_regions.mlir | 10 +- .../Flow/Transforms/test/form_streams.mlir | 216 +++--- .../test/hoist_unstreamable_ops.mlir | 34 +- .../test/identify_dispatch_regions.mlir | 18 +- ...ispatch_regions2_enable_matmul_fusion.mlir | 2 +- .../test/identify_dispatch_regions2_hlo.mlir | 19 +- .../identify_dispatch_regions2_linalg.mlir | 26 +- .../identify_dispatch_regions2_shapes.mlir | 2 +- ...identify_dispatch_regions2_std_fusion.mlir | 2 +- .../test/inject_dispatch_tracing.mlir | 12 +- .../test/outline_dispatch_regions2.mlir | 38 +- ...tline_dispatch_regions_ranked_dynamic.mlir | 44 -- .../rematerialize_dispatch_constants.mlir | 13 +- .../Flow/Transforms/test/transformation.mlir | 78 +-- .../Conversion/FlowToHAL/ConvertStreamOps.cpp | 28 +- .../Conversion/FlowToHAL/ConvertTensorOps.cpp | 6 +- .../Conversion/FlowToHAL/test/stream_ops.mlir | 52 +- .../HAL/Target/CUDA/test/smoketest.mlir | 41 +- .../HAL/Target/VMLA/test/i1_types.mlir | 5 +- .../Dialect/Shape/IR/ShapeInterfaces.td | 28 +- .../Conversion/HLOToVMLA/ConvertHLOToVMLA.cpp | 4 +- .../e2e/hackability/flow_partitioned.mlir | 2 +- 55 files changed, 1803 insertions(+), 1096 deletions(-) create mode 100644 iree/compiler/Dialect/Flow/IR/FlowInterfaces.td delete mode 100644 iree/compiler/Dialect/Flow/Transforms/test/outline_dispatch_regions_ranked_dynamic.mlir diff --git a/iree/compiler/Dialect/Flow/Conversion/HLOToFlow/ConvertHLOToFlow.cpp b/iree/compiler/Dialect/Flow/Conversion/HLOToFlow/ConvertHLOToFlow.cpp index 974fad9854a0..1a988b4ea0fd 100644 --- a/iree/compiler/Dialect/Flow/Conversion/HLOToFlow/ConvertHLOToFlow.cpp +++ b/iree/compiler/Dialect/Flow/Conversion/HLOToFlow/ConvertHLOToFlow.cpp @@ -53,7 +53,7 @@ struct DynamicUpdateSliceOpLowering rewriter.getIndexType()); })); rewriter.replaceOpWithNewOp( - op, op.getResult().getType(), op.update(), op.operand(), startIndices); + op, op.operand(), startIndices, op.update()); return success(); } }; diff --git a/iree/compiler/Dialect/Flow/IR/BUILD b/iree/compiler/Dialect/Flow/IR/BUILD index 0fc7a364440e..e6c4c6a630c5 100644 --- a/iree/compiler/Dialect/Flow/IR/BUILD +++ b/iree/compiler/Dialect/Flow/IR/BUILD @@ -90,7 +90,7 @@ gentbl( ("-gen-op-interface-defs", "FlowInterfaces.cpp.inc"), ], tblgen = "@llvm-project//mlir:mlir-tblgen", - td_file = "FlowBase.td", + td_file = "FlowInterfaces.td", td_srcs = [ ":td_files", "//iree/compiler/Dialect/IREE/IR:td_files", diff --git a/iree/compiler/Dialect/Flow/IR/CMakeLists.txt b/iree/compiler/Dialect/Flow/IR/CMakeLists.txt index cf8d5ecea3bb..581bae16396e 100644 --- a/iree/compiler/Dialect/Flow/IR/CMakeLists.txt +++ b/iree/compiler/Dialect/Flow/IR/CMakeLists.txt @@ -60,7 +60,7 @@ iree_tablegen_library( NAME FlowInterfacesGen TD_FILE - "FlowBase.td" + "FlowInterfaces.td" OUTS -gen-op-interface-decls FlowInterfaces.h.inc -gen-op-interface-defs FlowInterfaces.cpp.inc diff --git a/iree/compiler/Dialect/Flow/IR/FlowBase.td b/iree/compiler/Dialect/Flow/IR/FlowBase.td index 72d044f6ac4b..9e1b62ae8bb0 100644 --- a/iree/compiler/Dialect/Flow/IR/FlowBase.td +++ b/iree/compiler/Dialect/Flow/IR/FlowBase.td @@ -79,28 +79,6 @@ class FLOW_Op traits = []> : let printer = [{ return print$cppClass(p, *this); }]; } -def FLOW_StreamableOp : OpInterface<"StreamableOpInterface"> { - let description = [{ - Interface for ops that can be used within a stream. - - Some ops can exist both within a stream and outside of a stream. This allows - optimizations to place certain ops such that they are performed in a - synchronous (outside of a stream) or asynchronous (inside of a stream) - fashion. - - The goal of the stream forming process is to move as many operations that - can be used within a stream into one and only using non-streamed ops as a - last resort. - }]; - - let methods = [ - InterfaceMethod< - [{Returns true if the op is transfer operation (as defined by the HAL).}], - "bool", "isTransfer", (ins) - >, - ]; -} - //===----------------------------------------------------------------------===// // Flow dialect types //===----------------------------------------------------------------------===// @@ -108,6 +86,7 @@ def FLOW_StreamableOp : OpInterface<"StreamableOpInterface"> { def FLOW_PrimitiveType : AnyTypeOf<[Index, AnySignlessInteger, AnyFloat]>; def FLOW_Dim : TypeAlias; +def FLOW_ShapeDynamicDims : Variadic; def FLOW_Tensor : TypeAlias; diff --git a/iree/compiler/Dialect/Flow/IR/FlowInterfaces.td b/iree/compiler/Dialect/Flow/IR/FlowInterfaces.td new file mode 100644 index 000000000000..dba5aeafd140 --- /dev/null +++ b/iree/compiler/Dialect/Flow/IR/FlowInterfaces.td @@ -0,0 +1,116 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_DIALECT_FLOW_INTERFACES +#define IREE_DIALECT_FLOW_INTERFACES + +include "iree/compiler/Dialect/IREE/IR/IREEBase.td" + +//===----------------------------------------------------------------------===// +// IREE::Flow::ClosureOpInterface +//===----------------------------------------------------------------------===// + +def FLOW_ClosureOpInterface : OpInterface<"ClosureOpInterface"> { + let description = [{ + Interface for ops that follow the Flow dialect closure semantics (explicit + captures, dynamic-shape awareness, and normal operand/result SSA behavior). + + Implementing this interface enables optimizations that perform manipulation + across the closure capture boundary (outside of the op <-> regions within + the op). + }]; + + let methods = [ + InterfaceMethod< + /*desc=*/[{ + Returns the body region of the closure (may have multiple blocks). + }], + /*retTy=*/"Region &", + /*methodName=*/"getClosureBodyRegion", + /*args=*/(ins), + /*methodBody=*/[{}], + /*defaultImplementation=*/[{ + return this->getOperation()->getRegion(0); + }] + >, + InterfaceMethod< + /*desc=*/[{Returns all closure operand values.}], + /*retTy=*/"Operation::operand_range", + /*methodName=*/"getClosureOperands", + /*args=*/(ins) + >, + InterfaceMethod< + /*desc=*/[{Returns all closure result values.}], + /*retTy=*/"Operation::result_range", + /*methodName=*/"getClosureResults", + /*args=*/(ins) + >, + InterfaceMethod< + /*desc=*/[{ + Returns true if the given operation can exist in the closure. + Not all operations that a closure can contain are guaranteed to be folded + into the closure, such as when the operation may have side-effects. + }], + /*retTy=*/"bool", + /*methodName=*/"canClosureContainOp", + /*args=*/(ins "Operation *":$op) + >, + InterfaceMethod< + /*desc=*/[{ + Clones the op while removing specified operands and results. + The body of the op will be transferred to the new op and the entry block + will have its arguments removed. + + The returned op will be free standing. Callers must insert it into a block + where desired (most often just replacing the current op). + }], + /*retTy=*/"ClosureOpInterface", + /*methodName=*/"cloneReplacementExcludingOperandsAndResults", + /*args=*/(ins "ArrayRef":$excludedOperandIndices, + "ArrayRef":$excludedResultIndices) + >, + ]; +} + +//===----------------------------------------------------------------------===// +// IREE::Flow::StreamableOpInterface +//===----------------------------------------------------------------------===// + +def FLOW_StreamableOp : OpInterface<"StreamableOpInterface"> { + let description = [{ + Interface for ops that can be used within a stream. + + Some ops can exist both within a stream and outside of a stream. This allows + optimizations to place certain ops such that they are performed in a + synchronous (outside of a stream) or asynchronous (inside of a stream) + fashion. + + The goal of the stream forming process is to move as many operations that + can be used within a stream into one and only using non-streamed ops as a + last resort. + }]; + + let methods = [ + InterfaceMethod< + /*desc=*/[{ + Returns true if the op is transfer operation (as defined by the HAL). + }], + /*retTy=*/"bool", + /*methodName=*/"isTransfer", + /*args=*/(ins) + >, + ]; +} + +#endif // IREE_DIALECT_FLOW_INTERFACES diff --git a/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp b/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp index fde4c282af81..a3dbd3da19fe 100644 --- a/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp +++ b/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp @@ -24,6 +24,7 @@ #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/StringExtras.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" #include "mlir/IR/Matchers.h" @@ -41,41 +42,56 @@ namespace IREE { namespace Flow { //===----------------------------------------------------------------------===// -// Streams +// Folding utilities //===----------------------------------------------------------------------===// -namespace { - -// Optimizes stream fragment arguments by: -// - Removing any that are not used in the body -// - Deduping arguments that refer to the same Value -struct DceStreamFragment : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; +// Returns a new set of dynamic dimensions for a shape carrying op when a type +// is being changed. This attempts to reuse the existing dimension values if +// they are available and will drop/insert new ones as required. +static SmallVector refreshDimsOnTypeChange( + Operation *op, Type oldType, Type newType, ValueRange oldDims, + PatternRewriter &rewriter) { + if (oldType == newType) return llvm::to_vector<4>(oldDims); + + // Build an expanded list of all the dims - constants will be nullptr. + // This lets us map back the new types without worrying about whether some + // subset become static or dynamic. + auto oldShapedType = oldType.cast(); + SmallVector allOldDims(oldShapedType.getRank()); + for (unsigned i = 0; i < oldShapedType.getRank(); ++i) { + if (oldShapedType.isDynamicDim(i)) { + allOldDims[i] = oldDims.front(); + oldDims = oldDims.drop_front(); + } + } - LogicalResult matchAndRewrite(ExStreamFragmentOp op, - PatternRewriter &rewriter) const override { - if (op.body().empty()) return failure(); - ClosureOpDce dce(op, op.body().front(), /*variadicOffset=*/0); - if (!dce.needsOptimization()) return failure(); - - bool newOperation = dce.needsNewOperation(); - if (!newOperation) { - rewriter.startRootUpdate(op); - dce.optimize(rewriter); - rewriter.finalizeRootUpdate(op); - } else { - dce.optimize(rewriter, /*eraseOriginal=*/false); - rewriter.eraseOp(op); + auto newShapedType = newType.cast(); + SmallVector newDims; + for (unsigned i = 0; i < newShapedType.getRank(); ++i) { + if (newShapedType.isDynamicDim(i)) { + auto oldValue = allOldDims[i]; + if (oldValue) { + // Old value valid; reuse. + newDims.push_back(oldValue); + } else { + // Dimension has changed to be dynamic; insert a constant to use. + // This sometimes happens during folding of casts and usually is cleaned + // up pretty quickly. + newDims.push_back(rewriter.createOrFold( + op->getLoc(), oldShapedType.getDimSize(i))); + } } - return success(); } -}; + return newDims; +} -} // namespace +//===----------------------------------------------------------------------===// +// Streams +//===----------------------------------------------------------------------===// void ExStreamFragmentOp::getCanonicalizationPatterns( OwningRewritePatternList &results, MLIRContext *context) { - results.insert(context); + results.insert>(context); } //===----------------------------------------------------------------------===// @@ -223,70 +239,14 @@ void VariableStoreIndirectOp::getCanonicalizationPatterns( // Dispatch ops //===----------------------------------------------------------------------===// -namespace { - -struct DceDispatchRegion : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(DispatchRegionOp op, - PatternRewriter &rewriter) const override { - if (op.body().empty()) return failure(); - ClosureOpDce dce(op, op.body().front(), /*variadicOffset=*/1); - if (!dce.needsOptimization()) return failure(); - - bool newOperation = dce.needsNewOperation(); - if (!newOperation) { - rewriter.startRootUpdate(op); - dce.optimize(rewriter); - rewriter.finalizeRootUpdate(op); - } else { - dce.optimize(rewriter, /*eraseOriginal=*/false); - rewriter.eraseOp(op); - } - return success(); - } -}; - -} // namespace - void DispatchRegionOp::getCanonicalizationPatterns( OwningRewritePatternList &results, MLIRContext *context) { - results.insert(context); + results.insert>(context); } -namespace { - -struct DceDispatchWorkgroups : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(DispatchWorkgroupsOp op, - PatternRewriter &rewriter) const override { - if (op.body().empty()) return failure(); - ClosureOpDce dce(op, op.body().front(), - /*variadicOffset=*/op.workgroup_count().size()); - if (!dce.needsOptimization()) return failure(); - - bool newOperation = dce.needsNewOperation(); - if (!newOperation) { - rewriter.startRootUpdate(op); - dce.optimize(rewriter); - rewriter.finalizeRootUpdate(op); - } else { - dce.optimize(rewriter, /*eraseOriginal=*/false); - rewriter.eraseOp(op); - } - return success(); - } -}; - -} // namespace - void DispatchWorkgroupsOp::getCanonicalizationPatterns( OwningRewritePatternList &results, MLIRContext *context) { - // TODO(benvanik): add DCE support; it's tricky here because the ClosureOpDce - // assumes 1:1 operands/results between the outer op and inner region, but we - // don't do that there. - // results.insert(context); + results.insert>(context); } //===----------------------------------------------------------------------===// @@ -444,14 +404,6 @@ OpFoldResult TensorReshapeOp::fold(ArrayRef operands) { // No-op. return source(); } - - // Skip intermediate reshapes. - if (auto definingOp = - dyn_cast_or_null(source().getDefiningOp())) { - setOperand(definingOp.getOperand()); - return result(); - } - return {}; } @@ -606,12 +558,15 @@ static ElementsAttr tensorUpdate(ElementsAttr update, ElementsAttr target, } OpFoldResult TensorUpdateOp::fold(ArrayRef operands) { - auto indices = operands.drop_front(2); + auto targetIndex = getODSOperandIndexAndLength(0).first; + auto startIndices = getODSOperandIndexAndLength(2); + auto updateIndex = getODSOperandIndexAndLength(3).first; + auto indices = operands.slice(startIndices.first, startIndices.second); bool allIndicesConstant = llvm::count(indices, nullptr) == 0; - if (operands[0] && operands[1] && allIndicesConstant) { + if (operands[updateIndex] && operands[targetIndex] && allIndicesConstant) { // Fully constant arguments so we can perform the update here. - return tensorUpdate(operands[0].cast(), - operands[1].cast(), indices); + return tensorUpdate(operands[updateIndex].cast(), + operands[targetIndex].cast(), indices); } else { // Replace the entire tensor when the sizes match. auto updateType = update().getType().cast(); @@ -624,6 +579,42 @@ OpFoldResult TensorUpdateOp::fold(ArrayRef operands) { return {}; } +namespace { + +// When the target tensor is a result of a tensor.cast operation, the op needs +// to be updated to use the source of the cast as the target tensor. +struct FoldTensorUpdateOpWithCasts : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(TensorUpdateOp updateOp, + PatternRewriter &rewriter) const override { + auto targetCastOp = updateOp.target().getDefiningOp(); + auto updateCastOp = updateOp.update().getDefiningOp(); + if (!targetCastOp && !updateCastOp) return failure(); + auto target = (targetCastOp ? targetCastOp.source() : updateOp.target()); + auto update = (updateCastOp ? updateCastOp.source() : updateOp.update()); + auto newOp = rewriter.create( + updateOp.getLoc(), target.getType(), target, + refreshDimsOnTypeChange(updateOp, updateOp.target().getType(), + target.getType(), updateOp.target_dims(), + rewriter), + updateOp.start_indices(), update, + refreshDimsOnTypeChange(updateOp, updateOp.update().getType(), + update.getType(), updateOp.update_dims(), + rewriter)); + rewriter.replaceOpWithNewOp( + updateOp, updateOp.getResult().getType(), newOp.getResult()); + return success(); + } +}; + +} // namespace + +void TensorUpdateOp::getCanonicalizationPatterns( + OwningRewritePatternList &results, MLIRContext *context) { + results.insert(context); +} + } // namespace Flow } // namespace IREE } // namespace iree_compiler diff --git a/iree/compiler/Dialect/Flow/IR/FlowOpUtils.cpp b/iree/compiler/Dialect/Flow/IR/FlowOpUtils.cpp index eecec9acebe6..d188f7bbc3f2 100644 --- a/iree/compiler/Dialect/Flow/IR/FlowOpUtils.cpp +++ b/iree/compiler/Dialect/Flow/IR/FlowOpUtils.cpp @@ -14,131 +14,297 @@ #include "iree/compiler/Dialect/Flow/IR/FlowOpUtils.h" -#include "iree/compiler/Dialect/Flow/IR/FlowOps.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" namespace mlir { namespace iree_compiler { namespace IREE { namespace Flow { -Operation *cloneWithNewResultTypes(Operation *op, TypeRange newResultTypes) { - OperationState state(op->getLoc(), op->getName()); - state.addOperands(op->getOperands()); - state.addTypes(newResultTypes); - state.addSuccessors(op->getSuccessors()); - state.addAttributes(op->getAttrs()); - for (unsigned i = 0, e = op->getNumRegions(); i < e; ++i) { - state.addRegion(); +//------------------------------------------------------------------------------ +// Closure optimization +//------------------------------------------------------------------------------ + +void excludeClosureOperandsAndResults( + SmallVector &operandValues, + ArrayRef excludedOperandIndices, + SmallVector &resultTypes, + ArrayRef excludedResultIndices) { + SmallVector oldOperandValues = operandValues; + operandValues.clear(); + for (auto it : llvm::enumerate(oldOperandValues)) { + if (!llvm::count(excludedOperandIndices, it.index())) { + operandValues.push_back(it.value()); + } } - Operation *newOp = Operation::create(state); - for (unsigned i = 0, e = op->getNumRegions(); i < e; ++i) { - newOp->getRegion(i).takeBody(op->getRegion(i)); + SmallVector oldResultTypes = resultTypes; + resultTypes.clear(); + for (auto it : llvm::enumerate(oldResultTypes)) { + if (!llvm::count(excludedResultIndices, it.index())) { + resultTypes.push_back(it.value()); + } } - return newOp; } -//------------------------------------------------------------------------------ -// ClosureOpDce -//------------------------------------------------------------------------------ +void excludeClosureOperandsAndResults( + SmallVector &operandValues, SmallVector &operandDims, + ArrayRef excludedOperandIndices, + SmallVector &resultTypes, SmallVector &resultDims, + ArrayRef excludedResultIndices) { + SmallVector oldOperandValues = operandValues; + SmallVector oldOperandDims = operandDims; + operandValues.clear(); + operandDims.clear(); + auto remainingOperandDims = llvm::makeArrayRef(oldOperandDims); + for (auto it : llvm::enumerate(oldOperandValues)) { + unsigned numDynamicDims = 0; + auto type = it.value().getType(); + if (auto shapedType = type.dyn_cast()) { + numDynamicDims = shapedType.getNumDynamicDims(); + } + if (!llvm::count(excludedOperandIndices, it.index())) { + operandValues.push_back(it.value()); + for (auto dim : remainingOperandDims.take_front(numDynamicDims)) { + operandDims.push_back(dim); + } + } + remainingOperandDims = remainingOperandDims.drop_front(numDynamicDims); + } + + SmallVector oldResultTypes = resultTypes; + SmallVector oldResultDims = resultDims; + resultTypes.clear(); + resultDims.clear(); + auto remainingResultDims = llvm::makeArrayRef(oldResultDims); + for (auto it : llvm::enumerate(oldResultTypes)) { + unsigned numDynamicDims = 0; + auto type = it.value(); + if (auto shapedType = type.dyn_cast()) { + numDynamicDims = shapedType.getNumDynamicDims(); + } + if (!llvm::count(excludedResultIndices, it.index())) { + resultTypes.push_back(type); + for (auto dim : remainingResultDims.take_front(numDynamicDims)) { + resultDims.push_back(dim); + } + } + remainingResultDims = remainingResultDims.drop_front(numDynamicDims); + } +} -ClosureOpDce::ClosureOpDce(Operation *closureOp, Block &entryBlock, - unsigned variadicOffset) - : closureOp(closureOp), - entryBlock(entryBlock), - variadicOffset(variadicOffset), - blockArgReplacements(entryBlock.getNumArguments()) { - assert(closureOp->getNumOperands() == - entryBlock.getNumArguments() + variadicOffset); +void eraseRegionResults(Region ®ion, + ArrayRef excludedResultIndices) { + region.walk([&](IREE::Flow::ReturnOp terminator) { + llvm::SmallVector newReturns; + for (auto it : llvm::enumerate(terminator.getOperands())) { + if (!llvm::count(excludedResultIndices, it.index())) { + newReturns.push_back(it.value()); + } + } + terminator.getOperation()->setOperands(newReturns); + }); +} + +// Returns true if |constantOp| represents a (logically) small constant value. +// +// "Small" is relative and there's a risk that we'll bloat the closures by +// duplicating a bunch of constants however what we are able to save by not +// doing that usually wins. Think of the number of bytes used on instructions to +// allocate/place/copy, setup function call arguments, compute the address, +// dereference the memory, etc vs. a constant immediate value of 16 bytes - +// afterall, there are single x64 instructions that approach 15 bytes :) +// +// This is also still at a fairly high level (flow dialect): once the closures +// are expanded out in lower dialects things like CSE have a chance to once +// again get at the constants and dedupe them if they survive. +static bool isConstantSmall(ConstantOp constantOp) { + // We could tune this/take it as a configuration setting. + // The current value is chosen based on what is known to be reasonable to + // inline into command buffers way down in the HAL, which is not great but at + // least better than either allocating independent buffers for 4 byte + // constants or inlining megabytes. + static constexpr int kMaxInlinedConstantBytes = 256; + + auto constantValueAttr = constantOp.getValue(); + auto constantType = constantOp.getType(); + if (constantValueAttr.isa()) { + // Splats are always small and can often have special handling when we + // know they are a splat - which is why it's so important we inline them + // here so we know when they are used that's the case. + return true; + } else if (auto denseAttr = constantValueAttr.dyn_cast()) { + // Smallish constants are worth moving inside. + auto shapedType = constantType.cast(); + uint64_t estimatedByteLength = + (shapedType.getNumElements() * shapedType.getElementTypeBitWidth()) / 8; + return denseAttr.isSplat() || + estimatedByteLength <= kMaxInlinedConstantBytes; + } else if (constantType.isIntOrIndexOrFloat()) { + // Primitives can always go in. + return true; + } + + return false; +} + +// Returns true if the given value should be inlined into the closure region. +// This is non-recursive and only holds for this value. Recursively cloning +// trees is hard and it'd be better to model that differently such as by having +// a wrapper region for immutable blobs that can be inlined that this then +// returns true for. +static bool shouldInlineIntoClosure(Value value) { + auto definingOp = value.getDefiningOp(); + if (auto constantOp = dyn_cast(definingOp)) { + // Constants are perfect! + return isConstantSmall(constantOp); + } else if (auto variableLoadOp = + dyn_cast(definingOp)) { + // If the variable is immutable then we can inline the reference to it. + auto variableOp = + SymbolTable::lookupNearestSymbolFrom( + definingOp, variableLoadOp.variable()); + return !variableOp.is_mutable(); + } + return false; +} + +// Inlines operands of the closure into the entry block as appropriate. +// The closure operands and block arguments will remain untouched but all uses +// will be replaced with the newly cloned values. +// +// Note that if multiple operands reference the same value it will get cloned +// multiple times. That's fine, as anything we can inline here is something we +// should also be able to CSE and that happens later on anyway. +static void inlineClosureOperands(ClosureOpInterface &closureOp, + Block &entryBlock) { + auto builder = OpBuilder::atBlockBegin(&entryBlock); + for (auto opArg : llvm::enumerate(closureOp.getClosureOperands())) { + auto outerValue = opArg.value(); + auto *sourceOp = outerValue.getDefiningOp(); + if (!sourceOp) continue; // can't clone block arguments into closures + if (closureOp.canClosureContainOp(sourceOp) && + shouldInlineIntoClosure(outerValue)) { + // Clone the op (with regions). + auto *clonedOp = builder.clone(*sourceOp); + + // Ensure we are using the right result in the case of ops with multiple + // results. If we only end up using a single result then canonicalization + // should take care of removing the unneeded ones. + int resultIndex = + std::distance(sourceOp->result_begin(), + std::find(sourceOp->result_begin(), + sourceOp->result_end(), outerValue)); + auto newValue = clonedOp->getResult(resultIndex); + + // Replace all of the uses inside of the closure. + auto innerValue = entryBlock.getArgument(opArg.index()); + innerValue.replaceAllUsesWith(newValue); + } + } +} + +bool optimizeClosureLikeOp(ClosureOpInterface &closureOp, + PatternRewriter *rewriter) { + // NOTE: the block is transferred to the new op; we can update it in place. + Block &entryBlock = closureOp.getClosureBodyRegion().front(); + + // Find constants/metadata/etc that we can clone into the closure. + // By doing this first we potentially create some dead operands that we can + // then elide below. When we do inline things the operands will be changed + // such that the following work is guaranteed to happen and thus our op will + // be rebuilt. + inlineClosureOperands(closureOp, entryBlock); // Build data structure for unused operand elision. - for (auto it : llvm::enumerate(entryBlock.getArguments())) { - BlockArgument blockArg = it.value(); - Value opArg = closureOp->getOperand(it.index() + variadicOffset); - if (blockArg.getUses().empty()) { + SmallVector elidedOperands; + llvm::SmallMapVector argToBlockMap; + SmallVector, 8> blockArgReplacements( + entryBlock.getNumArguments()); + for (auto opArg : llvm::enumerate(closureOp.getClosureOperands())) { + auto blockArg = entryBlock.getArgument(opArg.index()); + if (blockArg.use_empty()) { // Not used - Drop. - needsOperandElision = true; - blockArgReplacements[it.index()] = BlockArgument(); + elidedOperands.push_back(opArg.index()); + blockArgReplacements[opArg.index()] = BlockArgument(); continue; } - auto existingIt = argToBlockMap.find(opArg); + auto existingIt = argToBlockMap.find(opArg.value()); if (existingIt == argToBlockMap.end()) { // Not found - Record for deduping. - argToBlockMap.insert(std::make_pair(opArg, blockArg)); + argToBlockMap.insert(std::make_pair(opArg.value(), blockArg)); } else { // Found - Replace. - needsOperandElision = true; - blockArgReplacements[it.index()] = existingIt->second; + elidedOperands.push_back(opArg.index()); + blockArgReplacements[opArg.index()] = existingIt->second; } } // Check for unused results. - for (auto result : closureOp->getResults()) { - if (result.getUses().empty()) { - needsResultElision = true; - break; + SmallVector preservedResults; + SmallVector elidedResults; + for (auto result : llvm::enumerate(closureOp.getClosureResults())) { + if (result.value().use_empty()) { + elidedResults.push_back(result.index()); + } else { + preservedResults.push_back(result.value()); } } -} -void ClosureOpDce::elideUnusedOperands(OpBuilder &builder) { - llvm::SmallVector newOperands( - closureOp->operand_begin(), closureOp->operand_begin() + variadicOffset); + if (elidedOperands.empty() && elidedResults.empty()) { + // No optimization required. + return false; + } + + if (elidedResults.size() == closureOp.getClosureResults().size()) { + // The op is completely unused - delete it. + if (rewriter) { + rewriter->eraseOp(closureOp); + } else { + closureOp.erase(); + } + closureOp = {}; + return true; + } + + // Replace duplicate block arguments. unsigned blockArgIndex = 0; - for (auto it : llvm::enumerate(blockArgReplacements)) { - llvm::Optional replacement = it.value(); - Value currentOpArg = closureOp->getOperand(it.index() + variadicOffset); + for (auto replacement : blockArgReplacements) { if (!replacement) { // No change. - newOperands.push_back(currentOpArg); blockArgIndex++; - continue; } else if (!replacement.getValue()) { - // Drop. - entryBlock.eraseArgument(blockArgIndex); - continue; + // Dropped. } else { - // Replace. - BlockArgument currentBlockArg = entryBlock.getArgument(blockArgIndex); - currentBlockArg.replaceAllUsesWith(*replacement); - entryBlock.eraseArgument(blockArgIndex); + // Replaced. + entryBlock.getArgument(blockArgIndex).replaceAllUsesWith(*replacement); } } - closureOp->setOperands(newOperands); -} - -void ClosureOpDce::elideUnusedResults(OpBuilder &builder, bool eraseOriginal) { - // Determine the result signature transform needed. - llvm::SmallVector resultIndexMap; - llvm::SmallVector newResultTypes; - for (auto it : llvm::enumerate(closureOp->getResults())) { - if (!it.value().getUses().empty()) { - newResultTypes.push_back(it.value().getType()); - resultIndexMap.push_back(it.index()); - } + // Clone the op with the elidable operands and results removed. + OpBuilder builder(closureOp); + auto newOp = closureOp.cloneReplacementExcludingOperandsAndResults( + elidedOperands, elidedResults); + if (rewriter) { + rewriter->insert(newOp); + } else { + builder.insert(newOp); } - // Re-allocate the op. - builder.setInsertionPoint(closureOp); - Operation *newOp = - builder.insert(cloneWithNewResultTypes(closureOp, newResultTypes)); - - // Remap all returns. - llvm::SmallVector newReturns(resultIndexMap.size()); - newOp->walk([&](IREE::Flow::ReturnOp terminator) { - for (unsigned i = 0, e = resultIndexMap.size(); i < e; ++i) { - newReturns[i] = terminator.getOperand(resultIndexMap[i]); - } - terminator.getOperation()->setOperands(newReturns); - }); + // Replace original uses of the closure results. + for (auto oldNewResult : + llvm::zip(preservedResults, newOp.getClosureResults())) { + std::get<0>(oldNewResult).replaceAllUsesWith(std::get<1>(oldNewResult)); + } - // Replace original uses. - for (unsigned i = 0, e = resultIndexMap.size(); i < e; ++i) { - closureOp->getResult(resultIndexMap[i]) - .replaceAllUsesWith(newOp->getResult(i)); + // Erase the original op. + if (rewriter) { + rewriter->eraseOp(closureOp); + } else { + closureOp.erase(); } - if (eraseOriginal) closureOp->erase(); + closureOp = newOp; + return true; } } // namespace Flow diff --git a/iree/compiler/Dialect/Flow/IR/FlowOpUtils.h b/iree/compiler/Dialect/Flow/IR/FlowOpUtils.h index 38b44d0be86a..2e2da8ae1dd3 100644 --- a/iree/compiler/Dialect/Flow/IR/FlowOpUtils.h +++ b/iree/compiler/Dialect/Flow/IR/FlowOpUtils.h @@ -12,56 +12,69 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include "iree/compiler/Dialect/Flow/IR/FlowDialect.h" +#include "iree/compiler/Dialect/Flow/IR/FlowOps.h" #include "llvm/ADT/MapVector.h" #include "llvm/ADT/SmallVector.h" #include "mlir/IR/Builders.h" #include "mlir/IR/Operation.h" +#include "mlir/IR/PatternMatch.h" namespace mlir { namespace iree_compiler { namespace IREE { namespace Flow { -// Clones an operation with new result types. -// The original operation will be erased and a new operation constructed -// in its place. -Operation *cloneWithNewResultTypes(Operation *op, TypeRange newResultTypes); +//------------------------------------------------------------------------------ +// Closure optimization +//------------------------------------------------------------------------------ -// Utility class to optimize a "closure" op, which maintains a variadic -// list of operands corresponding to entry block arguments. -class ClosureOpDce { - public: - ClosureOpDce(Operation *closureOp, Block &entryBlock, - unsigned variadicOffset); +// Modifies in-place the operand results vectors for a closure operation. +// |excludedOperandIndices| and |excludedResultIndices| are sets containing the +// operands and results in the lists to remove. +void excludeClosureOperandsAndResults(SmallVector &operandValues, + ArrayRef excludedOperandIndices, + SmallVector &resultTypes, + ArrayRef excludedResultIndices); +void excludeClosureOperandsAndResults(SmallVector &operandValues, + SmallVector &operandDims, + ArrayRef excludedOperandIndices, + SmallVector &resultTypes, + SmallVector &resultDims, + ArrayRef excludedResultIndices); - bool needsOptimization() { return needsOperandElision || needsResultElision; } +// Erases the given result indices from terminators in the given region. +void eraseRegionResults(Region ®ion, + ArrayRef excludedResultIndices); - // Whether the operation needs to be replaced. - bool needsNewOperation() { return needsResultElision; } +// Optimizes closure |closureOp| to remove duplicate operands and unused +// results. The op may be mutated, destroyed, or replaced with a new one. If an +// optional |rewriter| is provided then it will be notified of the operations +// performed on the op. Returns true if the op was optimized. +bool optimizeClosureLikeOp(ClosureOpInterface &closureOp, + PatternRewriter *rewriter = nullptr); +template +inline bool optimizeClosureOp(T &op, PatternRewriter *rewriter = nullptr) { + auto closureOp = cast(op.getOperation()); + bool didOptimize = optimizeClosureLikeOp(closureOp, rewriter); + op = dyn_cast_or_null(closureOp.getOperation()); + return didOptimize; +} - // Performs the optimization. If the optional eraseOriginal=false and - // needsNewOperation(), then the original will not be erased, leaving that - // to the caller (which is needed in some pattern rewriting scenarios). - // TODO(laurenzo): Fix OpBuilder upstream so that this eraseOriginal - // workaround is not required to write a safe rewriter pattern that uses this - // utility. - Operation *optimize(OpBuilder &builder, bool eraseOriginal = true) { - if (needsResultElision) elideUnusedResults(builder, eraseOriginal); - if (needsOperandElision) elideUnusedOperands(builder); - return closureOp; - } - - private: - void elideUnusedOperands(OpBuilder &builder); - void elideUnusedResults(OpBuilder &builder, bool eraseOriginal); +// A pattern that optimizes the given region-containing op T (CSE, DCE, etc). +// Duplicate operands will be combined and unused operands and results will be +// removed. +// +// T must implement the IREE::Flow::ClosureOpInterface. +template +struct ClosureOptimizationPattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; - Operation *closureOp; - Block &entryBlock; - unsigned variadicOffset; - llvm::SmallVector, 8> blockArgReplacements; - llvm::SmallMapVector argToBlockMap; - bool needsOperandElision = false; - bool needsResultElision = false; + LogicalResult matchAndRewrite(T op, + PatternRewriter &rewriter) const override { + auto closureOp = cast(op.getOperation()); + return optimizeClosureLikeOp(closureOp, &rewriter) ? success() : failure(); + } }; } // namespace Flow diff --git a/iree/compiler/Dialect/Flow/IR/FlowOps.cpp b/iree/compiler/Dialect/Flow/IR/FlowOps.cpp index 3bca4d5a86d7..4fff767d3b8a 100644 --- a/iree/compiler/Dialect/Flow/IR/FlowOps.cpp +++ b/iree/compiler/Dialect/Flow/IR/FlowOps.cpp @@ -16,9 +16,9 @@ #include "iree/compiler/Dialect/Flow/IR/FlowOpUtils.h" #include "iree/compiler/Dialect/IREE/IR/IREETypes.h" +#include "iree/compiler/Dialect/Shape/IR/Builders.h" #include "llvm/ADT/StringExtras.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/BlockAndValueMapping.h" #include "mlir/IR/Builders.h" @@ -37,6 +37,10 @@ namespace iree_compiler { namespace IREE { namespace Flow { +//===----------------------------------------------------------------------===// +// Op utilities used within the Flow dialect +//===----------------------------------------------------------------------===// + // Returns true if the given |accessType| is compatible with the |variableType|. // For example, this will return true if the variable type is a tensor // and the access is tensor<4xf32>. @@ -44,6 +48,107 @@ static bool isVariableTypeCompatible(Type variableType, Type accessType) { return succeeded(mlir::verifyCompatibleShape(variableType, accessType)); } +// Verifies that |dynamicDims| contains the appropriate number of dims for all +// of the dynamic dimensions in |values|. +static LogicalResult verifyOpDynamicDims(Operation *op, ValueRange values, + ValueRange dynamicDims) { + unsigned requiredCount = 0; + for (auto value : values) { + if (auto shapedType = value.getType().dyn_cast()) { + requiredCount += shapedType.getNumDynamicDims(); + } + } + if (dynamicDims.size() != requiredCount) { + return op->emitOpError() + << "value set has " << requiredCount + << " dynamic dimensions but only " << dynamicDims.size() + << " dimension values are attached"; + } + return success(); +} + +//===----------------------------------------------------------------------===// +// custom +//===----------------------------------------------------------------------===// +// (type, type{%dim0, %dim1}, type) -> type{%dim2} + +static ParseResult parseShapedTypeList( + OpAsmParser &parser, SmallVectorImpl &types, + SmallVectorImpl &dims) { + do { + Type type; + if (failed(parser.parseType(type))) return failure(); + if (auto shapedType = type.dyn_cast()) { + if (!shapedType.hasStaticShape()) { + if (failed(parser.parseLBrace()) || + failed(parser.parseOperandList(dims, shapedType.getNumDynamicDims(), + OpAsmParser::Delimiter::None)) || + failed(parser.parseRBrace())) { + return failure(); + } + } + } + types.push_back(type); + } while (succeeded(parser.parseOptionalComma())); + return success(); +} + +static void printShapedTypeList(OpAsmPrinter &p, Operation *op, TypeRange types, + OperandRange dims) { + llvm::interleaveComma(types, p, [&](Type type) { + p.printType(type); + if (auto shapedType = type.dyn_cast()) { + if (!shapedType.hasStaticShape()) { + if (dims.empty()) { + p << "{<>}"; + return; + } + p << "{"; + llvm::interleaveComma(dims.take_front(shapedType.getNumDynamicDims()), + p, [&](Value value) { p.printOperand(value); }); + p << "}"; + dims = dims.drop_front(shapedType.getNumDynamicDims()); + } + } + }); +} + +static ParseResult parseShapedFunctionType( + OpAsmParser &parser, SmallVectorImpl &operandTypes, + SmallVectorImpl &operandDims, + SmallVectorImpl &resultTypes, + SmallVectorImpl &resultDims) { + if (failed(parser.parseLParen())) return failure(); + if (failed(parser.parseOptionalRParen())) { + if (failed(parseShapedTypeList(parser, operandTypes, operandDims)) || + failed(parser.parseRParen())) { + return failure(); + } + } + if (failed(parser.parseArrow())) return failure(); + if (succeeded(parser.parseOptionalLParen())) { + if (failed(parseShapedTypeList(parser, resultTypes, resultDims)) || + failed(parser.parseRParen())) { + return failure(); + } + } else if (failed(parseShapedTypeList(parser, resultTypes, resultDims))) { + return failure(); + } + return success(); +} + +static void printShapedFunctionType(OpAsmPrinter &p, Operation *op, + TypeRange operandTypes, + OperandRange operandDims, + TypeRange resultTypes, + OperandRange resultDims) { + p << "("; + printShapedTypeList(p, op, operandTypes, operandDims); + p << ") -> ("; + printShapedTypeList(p, op, resultTypes, resultDims); + p << ")"; +} + //===----------------------------------------------------------------------===// // flow.variable //===----------------------------------------------------------------------===// @@ -329,10 +434,24 @@ DispatchRegionOp::formFromAnchorOp(Value workload, Operation *anchorOp, return std::make_pair(drOp, newAnchorOp); } -void DispatchRegionOp::dceOperandsAndResults(DispatchRegionOp &op) { - OpBuilder builder(op.getContext()); - ClosureOpDce dce(op, op.body().front(), /*variadicOffset=*/1); - op = llvm::cast(dce.optimize(builder)); +// Clones an operation with new result types. +// The original operation will be erased and a new operation constructed +// in its place. +static Operation *cloneWithNewResultTypes(Operation *op, + TypeRange newResultTypes) { + OperationState state(op->getLoc(), op->getName()); + state.addOperands(op->getOperands()); + state.addTypes(newResultTypes); + state.addSuccessors(op->getSuccessors()); + state.addAttributes(op->getAttrs()); + for (unsigned i = 0, e = op->getNumRegions(); i < e; ++i) { + state.addRegion(); + } + Operation *newOp = Operation::create(state); + for (unsigned i = 0, e = op->getNumRegions(); i < e; ++i) { + newOp->getRegion(i).takeBody(op->getRegion(i)); + } + return newOp; } ResultRange DispatchRegionOp::appendResults(DispatchRegionOp &self, @@ -489,10 +608,9 @@ void printDispatchRegionOp(OpAsmPrinter &p, DispatchRegionOp op) { // Print the result types, if any. if (op.getNumResults() > 0) { - p << " -> "; - if (op.getNumResults() > 1) p << "("; + p << " -> ("; interleaveComma(op.getResultTypes(), p); - if (op.getNumResults() > 1) p << ")"; + p << ")"; } p.printRegion(op.body(), /*printEntryBlockArgs=*/false); @@ -500,22 +618,77 @@ void printDispatchRegionOp(OpAsmPrinter &p, DispatchRegionOp op) { /*elidedAttrs=*/{}); } +Operation::operand_range DispatchRegionOp::getClosureOperands() { + return args(); +} + +Operation::result_range DispatchRegionOp::getClosureResults() { + return results(); +} + +// TODO(#4897): allow non-splat constants - current paths can't handle them. +static bool canDispatchRegionContainOpIssue4897(Operation *op) { + if (auto constantOp = dyn_cast(op)) { + auto constantValueAttr = constantOp.getValue(); + auto constantType = constantOp.getType(); + if (constantValueAttr.isa()) { + return true; + } else if (auto denseAttr = + constantValueAttr.dyn_cast()) { + return denseAttr.isSplat(); + } else if (constantType.isIntOrIndexOrFloat()) { + return true; + } + } + return false; +} + +bool DispatchRegionOp::canClosureContainOp(Operation *op) { + return canDispatchRegionContainOpIssue4897(op); +} + +ClosureOpInterface +DispatchRegionOp::cloneReplacementExcludingOperandsAndResults( + ArrayRef excludedOperandIndices, + ArrayRef excludedResultIndices) { + SmallVector newResultTypes = llvm::to_vector<4>(getResultTypes()); + SmallVector newOperandsValues = llvm::to_vector<4>(args()); + excludeClosureOperandsAndResults(newOperandsValues, excludedOperandIndices, + newResultTypes, excludedResultIndices); + auto newOp = OpBuilder(getContext()) + .create(getLoc(), newResultTypes, + workload(), newOperandsValues, + getOperation()->getAttrs()); + auto &newBody = newOp.getClosureBodyRegion(); + newBody.takeBody(getClosureBodyRegion()); + eraseRegionResults(newBody, excludedResultIndices); + newBody.front().eraseArguments(excludedOperandIndices); + return newOp; +} + //===----------------------------------------------------------------------===// // flow.dispatch.workgroups //===----------------------------------------------------------------------===// void DispatchWorkgroupsOp::build(OpBuilder &builder, OperationState &state, ValueRange workgroupCount, - TypeRange resultTypes, ValueRange operands, + TypeRange resultTypes, ValueRange resultDims, + ValueRange operands, ValueRange operandDims, ArrayRef attributes) { state.addTypes(resultTypes); state.addOperands(workgroupCount); state.addOperands(operands); + state.addOperands(operandDims); + state.addOperands(resultDims); state.addAttributes(attributes); - state.addAttribute( - "operand_segment_sizes", - builder.getI32VectorAttr({static_cast(workgroupCount.size()), - static_cast(operands.size())})); + state.attributes.erase("operand_segment_sizes"); + state.addAttribute("operand_segment_sizes", + builder.getI32VectorAttr({ + static_cast(workgroupCount.size()), + static_cast(operands.size()), + static_cast(operandDims.size()), + static_cast(resultDims.size()), + })); auto *body = state.addRegion(); assert(body->begin() == body->end()); @@ -581,7 +754,7 @@ static void printDispatchWorkgroupBody(OpAsmPrinter &p, Operation *op, p << "("; interleaveComma(body.getArguments(), p, [&](BlockArgument arg) { p << arg; - p << " : "; + p << ": "; p << arg.getType(); }); p << ")"; @@ -589,93 +762,67 @@ static void printDispatchWorkgroupBody(OpAsmPrinter &p, Operation *op, /*printBlockTerminators=*/true); } -// TODO(benvanik): remove after https://bugs.llvm.org/show_bug.cgi?id=48478 -// The parser/printer are modified autogenerated values to work around the bug. - -static ::mlir::ParseResult parseDispatchWorkgroupsOp( - ::mlir::OpAsmParser &parser, ::mlir::OperationState *result) { - ::mlir::SmallVector<::mlir::OpAsmParser::OperandType, 4> - workgroup_countOperands; - ::llvm::SMLoc workgroup_countOperandsLoc; - (void)workgroup_countOperandsLoc; - ::mlir::SmallVector<::mlir::OpAsmParser::OperandType, 4> operandsOperands; - ::llvm::SMLoc operandsOperandsLoc; - (void)operandsOperandsLoc; - ::llvm::ArrayRef<::mlir::Type> operandsTypes; - ::llvm::ArrayRef<::mlir::Type> resultsTypes; - std::unique_ptr<::mlir::Region> bodyRegion = - std::make_unique<::mlir::Region>(); - if (parser.parseLSquare()) return ::mlir::failure(); - - workgroup_countOperandsLoc = parser.getCurrentLocation(); - if (parser.parseOperandList(workgroup_countOperands)) - return ::mlir::failure(); - if (parser.parseRSquare()) return ::mlir::failure(); - if (parser.parseLParen()) return ::mlir::failure(); - - operandsOperandsLoc = parser.getCurrentLocation(); - if (parser.parseOperandList(operandsOperands)) return ::mlir::failure(); - if (parser.parseRParen()) return ::mlir::failure(); - if (parser.parseColon()) return ::mlir::failure(); - - ::mlir::FunctionType operands__results_functionType; - if (parser.parseType(operands__results_functionType)) - return ::mlir::failure(); - operandsTypes = operands__results_functionType.getInputs(); - resultsTypes = operands__results_functionType.getResults(); - if (parser.parseOptionalAttrDictWithKeyword(result->attributes)) - return ::mlir::failure(); - if (parser.parseEqual()) return ::mlir::failure(); - { - if (parseDispatchWorkgroupBody(parser, operandsTypes, resultsTypes, - *bodyRegion)) - return ::mlir::failure(); - } - ::mlir::Type odsBuildableType0 = parser.getBuilder().getIndexType(); - result->addTypes(resultsTypes); - if (parser.resolveOperands(workgroup_countOperands, odsBuildableType0, - result->operands)) - return ::mlir::failure(); - if (parser.resolveOperands(operandsOperands, operandsTypes, - operandsOperandsLoc, result->operands)) - return ::mlir::failure(); - result->addRegion(std::move(bodyRegion)); - result->addAttribute( - "operand_segment_sizes", - parser.getBuilder().getI32VectorAttr( - {static_cast(workgroup_countOperands.size()), - static_cast(operandsOperands.size())})); - return ::mlir::success(); -} - -static void printDispatchWorkgroupsOp(::mlir::OpAsmPrinter &p, - DispatchWorkgroupsOp &op) { - p << "flow.dispatch.workgroups"; - p << "["; - p << op.workgroup_count(); - p << "]"; - p << ' ' << "("; - p << op.operands(); - p << ")"; - p << ' ' << ":"; - p << ' '; - p.printFunctionalType(op.operands().getTypes(), op.results().getTypes()); - p.printOptionalAttrDictWithKeyword(op->getAttrs(), /*elidedAttrs=*/{ - "operand_segment_sizes", - }); - p << ' ' << "="; - p << ' '; - printDispatchWorkgroupBody(p, op, op.operands().getTypes(), - op.results().getTypes(), op.body()); -} - static LogicalResult verifyDispatchWorkgroupsOp(DispatchWorkgroupsOp op) { if (op.workgroup_count().empty()) { return op.emitOpError() << "at least one workgroup dimension is required"; } + if (failed(verifyOpDynamicDims(op, op.operands(), op.operand_dims())) || + failed(verifyOpDynamicDims(op, op.results(), op.result_dims()))) { + return failure(); + } return success(); } +Value DispatchWorkgroupsOp::buildOperandRankedShape(unsigned idx, + OpBuilder &builder) { + return Shape::buildRankedShapeForValueInList(getLoc(), idx, getOperands(), + operand_dims(), builder); +} + +Value DispatchWorkgroupsOp::buildResultRankedShape(unsigned idx, + OpBuilder &builder) { + return Shape::buildRankedShapeForValueInList(getLoc(), idx, getResults(), + result_dims(), builder); +} + +Operation::operand_range DispatchWorkgroupsOp::getClosureOperands() { + return operands(); +} + +Operation::result_range DispatchWorkgroupsOp::getClosureResults() { + return results(); +} + +bool DispatchWorkgroupsOp::canClosureContainOp(Operation *op) { + return canDispatchRegionContainOpIssue4897(op); +} + +ClosureOpInterface +DispatchWorkgroupsOp::cloneReplacementExcludingOperandsAndResults( + ArrayRef excludedOperandIndices, + ArrayRef excludedResultIndices) { + SmallVector newResultTypes = llvm::to_vector<4>(getResultTypes()); + SmallVector newResultDims = llvm::to_vector<4>(result_dims()); + SmallVector newOperandsValues = llvm::to_vector<4>(operands()); + SmallVector newOperandDims = llvm::to_vector<4>(operand_dims()); + excludeClosureOperandsAndResults(newOperandsValues, newOperandDims, + excludedOperandIndices, newResultTypes, + newResultDims, excludedResultIndices); + auto newOp = + OpBuilder(getContext()) + .create( + getLoc(), workgroup_count(), newResultTypes, newResultDims, + newOperandsValues, newOperandDims, getOperation()->getAttrs()); + auto &newBody = newOp.getClosureBodyRegion(); + newBody.takeBody(getClosureBodyRegion()); + newBody.front().eraseArguments(excludedOperandIndices); + unsigned baseResultIndex = newBody.front().getNumArguments(); + newBody.front().eraseArguments(llvm::to_vector<4>(llvm::map_range( + excludedResultIndices, + [&](unsigned index) { return baseResultIndex + index; }))); + return newOp; +} + //===----------------------------------------------------------------------===// // flow.dispatch.workgroup.* //===----------------------------------------------------------------------===// @@ -842,7 +989,8 @@ static void printDispatchEntryOp(OpAsmPrinter &p, DispatchEntryOp op) { void DispatchOp::build(OpBuilder &builder, OperationState &state, DispatchEntryOp entryPoint, ValueRange workgroupCount, - TypeRange results, ValueRange operands, + TypeRange resultTypes, ValueRange resultDims, + ValueRange operands, ValueRange operandDims, ArrayRef attributes) { StringRef executableOpSymName = entryPoint->getParentOp() @@ -854,13 +1002,19 @@ void DispatchOp::build(OpBuilder &builder, OperationState &state, {builder.getSymbolRefAttr(entryPoint)})); state.addOperands(workgroupCount); - state.addTypes(results); + state.addTypes(resultTypes); state.addOperands(operands); + state.addOperands(operandDims); + state.addOperands(resultDims); state.addAttributes(attributes); - state.addAttribute( - "operand_segment_sizes", - builder.getI32VectorAttr({static_cast(workgroupCount.size()), - static_cast(operands.size())})); + state.attributes.erase("operand_segment_sizes"); + state.addAttribute("operand_segment_sizes", + builder.getI32VectorAttr({ + static_cast(workgroupCount.size()), + static_cast(operands.size()), + static_cast(operandDims.size()), + static_cast(resultDims.size()), + })); } StringRef DispatchOp::executable() { return entry_point().getRootReference(); } @@ -874,121 +1028,266 @@ static LogicalResult verifyDispatchOp(DispatchOp op) { if (op.workgroup_count().empty()) { return op.emitOpError() << "at least one workgroup dimension is required"; } + if (failed(verifyOpDynamicDims(op, op.operands(), op.operand_dims())) || + failed(verifyOpDynamicDims(op, op.results(), op.result_dims()))) { + return failure(); + } return success(); } +Value DispatchOp::buildOperandRankedShape(unsigned idx, OpBuilder &builder) { + return Shape::buildRankedShapeForValueInList(getLoc(), idx, getOperands(), + operand_dims(), builder); +} + +Value DispatchOp::buildResultRankedShape(unsigned idx, OpBuilder &builder) { + return Shape::buildRankedShapeForValueInList(getLoc(), idx, getResults(), + result_dims(), builder); +} + +//===----------------------------------------------------------------------===// +// flow.tensor.* +//===----------------------------------------------------------------------===// + +Value TensorReshapeOp::buildOperandRankedShape(unsigned idx, + OpBuilder &builder) { + return Shape::buildRankedShapeForValue(getLoc(), source(), source_dims(), + builder); +} + +Value TensorReshapeOp::buildResultRankedShape(unsigned idx, + OpBuilder &builder) { + return Shape::buildRankedShapeForValue(getLoc(), result(), result_dims(), + builder); +} + +Value TensorLoadOp::buildOperandRankedShape(unsigned idx, OpBuilder &builder) { + return Shape::buildRankedShapeForValue(getLoc(), source(), source_dims(), + builder); +} + +Value TensorLoadOp::buildResultRankedShape(unsigned idx, OpBuilder &builder) { + return {}; +} + +Value TensorStoreOp::buildOperandRankedShape(unsigned idx, OpBuilder &builder) { + return Shape::buildRankedShapeForValue(getLoc(), target(), target_dims(), + builder); +} + +Value TensorStoreOp::buildResultRankedShape(unsigned idx, OpBuilder &builder) { + return Shape::buildRankedShapeForValue(getLoc(), result(), target_dims(), + builder); +} + +Value TensorSplatOp::buildOperandRankedShape(unsigned idx, OpBuilder &builder) { + return {}; +} + +Value TensorSplatOp::buildResultRankedShape(unsigned idx, OpBuilder &builder) { + return Shape::buildRankedShapeForValue(getLoc(), result(), result_dims(), + builder); +} + +Value TensorCloneOp::buildOperandRankedShape(unsigned idx, OpBuilder &builder) { + return Shape::buildRankedShapeForValue(getLoc(), operand(), operand_dims(), + builder); +} + +Value TensorCloneOp::buildResultRankedShape(unsigned idx, OpBuilder &builder) { + return Shape::buildRankedShapeForValue(getLoc(), result(), operand_dims(), + builder); +} + +Value TensorSliceOp::buildOperandRankedShape(unsigned idx, OpBuilder &builder) { + return Shape::buildRankedShapeForValue(getLoc(), source(), source_dims(), + builder); +} + +Value TensorSliceOp::buildResultRankedShape(unsigned idx, OpBuilder &builder) { + return Shape::buildRankedShapeForValue(getLoc(), result(), result_dims(), + builder); +} + +//===----------------------------------------------------------------------===// +// flow.tensor.update +//===----------------------------------------------------------------------===// + +void TensorUpdateOp::build(OpBuilder &builder, OperationState &state, + Value target, ValueRange startIndices, + Value update) { + auto targetDims = + Shape::buildOrFindDynamicDimsForValue(state.location, target, builder); + auto updateDims = + Shape::buildOrFindDynamicDimsForValue(state.location, update, builder); + build(builder, state, target.getType(), target, targetDims, startIndices, + update, updateDims); +} + +static LogicalResult verifyTensorUpdateOp(TensorUpdateOp op) { + if (failed(verifyOpDynamicDims(op, {op.update()}, op.update_dims())) || + failed(verifyOpDynamicDims(op, {op.target()}, op.target_dims()))) { + return failure(); + } + return success(); +} + +Value TensorUpdateOp::buildOperandRankedShape(unsigned idx, + OpBuilder &builder) { + switch (idx) { + case 0: + return Shape::buildRankedShapeForValue(getLoc(), update(), update_dims(), + builder); + case 2: + return Shape::buildRankedShapeForValue(getLoc(), target(), target_dims(), + builder); + default: + llvm_unreachable("unshaped operand"); + } +} + +Value TensorUpdateOp::buildResultRankedShape(unsigned idx, OpBuilder &builder) { + return Shape::buildRankedShapeForValue(getLoc(), target(), target_dims(), + builder); +} + //===----------------------------------------------------------------------===// // flow.ex.stream.fragment //===----------------------------------------------------------------------===// void ExStreamFragmentOp::build(OpBuilder &builder, OperationState &state, - ArrayRef resultTypes, ValueRange operands, + TypeRange resultTypes, ValueRange resultDims, + ValueRange operands, ValueRange operandDims, ArrayRef attributes) { state.addTypes(resultTypes); state.addOperands(operands); + state.addOperands(operandDims); + state.addOperands(resultDims); state.addAttributes(attributes); + state.attributes.erase("operand_segment_sizes"); + state.addAttribute("operand_segment_sizes", + builder.getI32VectorAttr({ + static_cast(operands.size()), + static_cast(operandDims.size()), + static_cast(resultDims.size()), + })); state.addRegion(); } -ParseResult parseExStreamFragmentOp(OpAsmParser &parser, - OperationState *result) { +static LogicalResult verifyExStreamFragmentOp(ExStreamFragmentOp op) { + if (failed(verifyOpDynamicDims(op, op.operands(), op.operand_dims())) || + failed(verifyOpDynamicDims(op, op.results(), op.result_dims()))) { + return failure(); + } + return success(); +} + +static ParseResult parseStreamFragmentBody(OpAsmParser &parser, + TypeRange operandTypes, + TypeRange resultTypes, + Region &body) { + auto loc = parser.getCurrentLocation(); + SmallVector regionArgs; SmallVector regionArgTypes; if (failed(parser.parseLParen())) { return failure(); } if (failed(parser.parseOptionalRParen())) { - SmallVector regionOperands; - auto argsLoc = parser.getCurrentLocation(); do { // Reserve entries in the lists. regionArgs.emplace_back(); - regionOperands.emplace_back(); regionArgTypes.emplace_back(); if (failed(parser.parseRegionArgument(regionArgs.back())) || - failed(parser.parseEqual()) || - failed(parser.parseOperand(regionOperands.back())) || failed(parser.parseColonType(regionArgTypes.back()))) { return failure(); } } while (succeeded(parser.parseOptionalComma())); - if (failed(parser.parseRParen()) || - failed(parser.resolveOperands(regionOperands, regionArgTypes, argsLoc, - result->operands))) { + if (failed(parser.parseRParen())) { return failure(); } } - // Parse (optional) results. - if (failed(parser.parseOptionalArrowTypeList(result->types))) { - return failure(); - } + SmallVector regionResultTypes; + if (failed(parser.parseArrowTypeList(regionResultTypes))) return failure(); - // Parse region body. - Region *body = result->addRegion(); - if (failed(parser.parseRegion(*body, regionArgs, regionArgTypes)) || - failed(parser.parseOptionalAttrDict(result->attributes))) { - return failure(); + if (regionArgs.size() != operandTypes.size()) { + return parser.emitError(loc, "region operand list mismatch"); + } + if (regionResultTypes.size() != resultTypes.size()) { + return parser.emitError(loc, "region result list mismatch"); } - return success(); -} -void printExStreamFragmentOp(OpAsmPrinter &p, ExStreamFragmentOp op) { - p << op.getOperationName(); + return parser.parseRegion(body, regionArgs, regionArgTypes, + /*enableNameShadowing=*/true); +} - // Print the data argument remapping. +static void printStreamFragmentBody(OpAsmPrinter &p, Operation *op, + TypeRange operandTypes, + TypeRange resultTypes, Region &body) { p << "("; - interleaveComma(llvm::zip(op.body().getArguments(), op.args()), p, - [&](std::tuple it) { - p << std::get<0>(it) << " = " << std::get<1>(it); - p << " : "; - p << std::get<1>(it).getType(); - }); + llvm::interleaveComma(body.getArguments(), p, [&](BlockArgument arg) { + p << arg; + p << ": "; + p << arg.getType(); + }); + p << ") -> ("; + llvm::interleaveComma(resultTypes, p, [&](Type type) { p.printType(type); }); p << ")"; + p.printRegion(body, /*printEntryBlockArgs=*/false, + /*printBlockTerminators=*/true); +} - // Print the result types, if any. - if (op.getNumResults() > 0) { - p << " -> "; - if (op.getNumResults() > 1) p << "("; - interleaveComma(op.getResultTypes(), p); - if (op.getNumResults() > 1) p << ")"; - } +Value ExStreamFragmentOp::buildOperandRankedShape(unsigned idx, + OpBuilder &builder) { + return Shape::buildRankedShapeForValueInList(getLoc(), idx, getOperands(), + operand_dims(), builder); +} - p.printRegion(op.body(), /*printEntryBlockArgs=*/false); - p.printOptionalAttrDict(op->getAttrs(), - /*elidedAttrs=*/{}); +Value ExStreamFragmentOp::buildResultRankedShape(unsigned idx, + OpBuilder &builder) { + return Shape::buildRankedShapeForValueInList(getLoc(), idx, getResults(), + result_dims(), builder); } -//===----------------------------------------------------------------------===// -// flow.tensor.update -//===----------------------------------------------------------------------===// +Operation::operand_range ExStreamFragmentOp::getClosureOperands() { + return operands(); +} + +Operation::result_range ExStreamFragmentOp::getClosureResults() { + return results(); +} + +bool ExStreamFragmentOp::canClosureContainOp(Operation *op) { + // NOTE: we widen support on new stream ops only - the legacy path isn't worth + // upgrading to support more. + if (auto constantOp = dyn_cast(op)) { + return constantOp.getType().isIntOrIndexOrFloat(); + } + return false; +} -namespace { -// When the target tensor is a result of a tensor.cast operation, the op needs -// to be updated to use the source of the cast as the target tensor. -struct FoldTensorUpdateOpWithCasts : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(TensorUpdateOp updateOp, - PatternRewriter &rewriter) const override { - auto targetCastOp = updateOp.target().getDefiningOp(); - auto updateCastOp = updateOp.update().getDefiningOp(); - if (!targetCastOp && !updateCastOp) return failure(); - auto target = (targetCastOp ? targetCastOp.source() : updateOp.target()); - auto update = (updateCastOp ? updateCastOp.source() : updateOp.update()); - auto newOp = rewriter.create( - updateOp.getLoc(), target.getType(), update, target, - updateOp.start_indices()); - rewriter.replaceOpWithNewOp( - updateOp, updateOp.getResult().getType(), newOp.getResult()); - return success(); - } -}; -} // namespace - -void TensorUpdateOp::getCanonicalizationPatterns( - OwningRewritePatternList &results, MLIRContext *context) { - results.insert(context); +ClosureOpInterface +ExStreamFragmentOp::cloneReplacementExcludingOperandsAndResults( + ArrayRef excludedOperandIndices, + ArrayRef excludedResultIndices) { + SmallVector newResultTypes = llvm::to_vector<4>(getResultTypes()); + SmallVector newResultDims = llvm::to_vector<4>(result_dims()); + SmallVector newOperandsValues = llvm::to_vector<4>(operands()); + SmallVector newOperandDims = llvm::to_vector<4>(operand_dims()); + excludeClosureOperandsAndResults(newOperandsValues, newOperandDims, + excludedOperandIndices, newResultTypes, + newResultDims, excludedResultIndices); + auto newOp = + OpBuilder(getContext()) + .create(getLoc(), newResultTypes, newResultDims, + newOperandsValues, newOperandDims, + getOperation()->getAttrs()); + auto &newBody = newOp.getClosureBodyRegion(); + newBody.takeBody(getClosureBodyRegion()); + eraseRegionResults(newBody, excludedResultIndices); + newBody.front().eraseArguments(excludedOperandIndices); + return newOp; } } // namespace Flow diff --git a/iree/compiler/Dialect/Flow/IR/FlowOps.h b/iree/compiler/Dialect/Flow/IR/FlowOps.h index 0fda24fd1844..fae825805725 100644 --- a/iree/compiler/Dialect/Flow/IR/FlowOps.h +++ b/iree/compiler/Dialect/Flow/IR/FlowOps.h @@ -20,6 +20,7 @@ #include "iree/compiler/Dialect/Flow/IR/FlowDialect.h" #include "iree/compiler/Dialect/Flow/IR/FlowTypes.h" #include "iree/compiler/Dialect/IREE/IR/IREETraits.h" +#include "iree/compiler/Dialect/Shape/IR/ShapeDialect.h" #include "iree/compiler/Dialect/Shape/IR/ShapeTypes.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/BuiltinOps.h" diff --git a/iree/compiler/Dialect/Flow/IR/FlowOps.td b/iree/compiler/Dialect/Flow/IR/FlowOps.td index c6f254835923..a90c276d118b 100644 --- a/iree/compiler/Dialect/Flow/IR/FlowOps.td +++ b/iree/compiler/Dialect/Flow/IR/FlowOps.td @@ -16,6 +16,8 @@ #define IREE_DIALECT_FLOW_OPS include "iree/compiler/Dialect/Flow/IR/FlowBase.td" +include "iree/compiler/Dialect/Flow/IR/FlowInterfaces.td" +include "iree/compiler/Dialect/Shape/IR/ShapeInterfaces.td" include "mlir/IR/OpAsmInterface.td" include "mlir/IR/SymbolInterfaces.td" include "mlir/Interfaces/InferTypeOpInterface.td" @@ -166,6 +168,7 @@ def FLOW_VariableStoreIndirectOp : FLOW_Op<"variable.store.indirect"> { def FLOW_DispatchRegionOp : FLOW_PureOp<"dispatch.region", [ IsolatedFromAbove, + DeclareOpInterfaceMethods, ]> { let summary = [{partitioned region representing a dispatched workload}]; let description = [{ @@ -202,11 +205,6 @@ def FLOW_DispatchRegionOp : FLOW_PureOp<"dispatch.region", [ formFromAnchorOp(Value workload, Operation *anchorOp, OpBuilder &builder); - /// Performs an in-place DCE optimization on unused operands and results. - /// Note that this may or may not re-allocate the op. If so, the reference - /// will be updated. - static void dceOperandsAndResults(DispatchRegionOp &op); - // Appends results to the dispatch region. This will re-allocate the // DispatchRegionOp itself but preserve the contained body block. // Returns a ResultRange for the new dispatch region op's results @@ -241,7 +239,9 @@ def FLOW_DispatchRegionOp : FLOW_PureOp<"dispatch.region", [ def FLOW_DispatchWorkgroupsOp : FLOW_PureOp<"dispatch.workgroups", [ IsolatedFromAbove, AttrSizedOperandSegments, - SingleBlockImplicitTerminator<"IREE::Flow::ReturnOp"> + SingleBlockImplicitTerminator<"IREE::Flow::ReturnOp">, + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods, ]> { let summary = [{a dispatch of workgroups across an n-dimension grid}]; let description = [{ @@ -281,7 +281,9 @@ def FLOW_DispatchWorkgroupsOp : FLOW_PureOp<"dispatch.workgroups", [ let arguments = (ins Variadic:$workgroup_count, - Variadic:$operands + Variadic:$operands, + FLOW_ShapeDynamicDims:$operand_dims, + FLOW_ShapeDynamicDims:$result_dims ); let results = (outs Variadic:$results @@ -289,23 +291,24 @@ def FLOW_DispatchWorkgroupsOp : FLOW_PureOp<"dispatch.workgroups", [ let regions = (region AnyRegion:$body); - // TODO(benvanik): use after https://bugs.llvm.org/show_bug.cgi?id=48478 - // let assemblyFormat = [{ - // `[` $workgroup_count `]` - // `(` $operands `)` `:` - // functional-type($operands, $results) - // attr-dict-with-keyword - // `=` - // custom(type_ref($operands), - // type_ref($results), - // $body) - // }]; + let assemblyFormat = [{ + `[` $workgroup_count `]` `` + `(` $operands `)` `:` + custom(type($operands), $operand_dims, + type($results), $result_dims) + attr-dict-with-keyword + `=` `\n` ` ` ` ` ` ` + custom(ref(type($operands)), + ref(type($results)), + $body) + }]; let skipDefaultBuilders = 1; let builders = [ OpBuilder<(ins "ValueRange":$workgroupCount, - "TypeRange":$resultTypes, "ValueRange":$operands, + "TypeRange":$resultTypes, "ValueRange":$resultDims, + "ValueRange":$operands, "ValueRange":$operandDims, CArg<"ArrayRef", "{}">:$attributes)>, ]; @@ -652,6 +655,7 @@ def FLOW_DispatchEntryOp : FLOW_Op<"dispatch.entry", [ def FLOW_DispatchOp : FLOW_PureOp<"dispatch", [ AttrSizedOperandSegments, FLOW_StreamableOp, + DeclareOpInterfaceMethods, ]> { let summary = [{a dispatch of workgroups across an n-dimension grid}]; let description = [{ @@ -663,7 +667,9 @@ def FLOW_DispatchOp : FLOW_PureOp<"dispatch", [ let arguments = (ins Variadic:$workgroup_count, SymbolRefAttr:$entry_point, - Variadic:$operands + Variadic:$operands, + FLOW_ShapeDynamicDims:$operand_dims, + FLOW_ShapeDynamicDims:$result_dims ); let results = (outs Variadic:$results @@ -673,7 +679,8 @@ def FLOW_DispatchOp : FLOW_PureOp<"dispatch", [ let builders = [ OpBuilder<(ins "DispatchEntryOp":$entryPoint, "ValueRange":$workgroupCount, - "TypeRange":$resultTypes, CArg<"ValueRange", "{}">:$operands, + "TypeRange":$resultTypes, "ValueRange":$resultDims, + "ValueRange":$operands, "ValueRange":$operandDims, CArg<"ArrayRef", "{}">:$attributes)>, ]; @@ -686,9 +693,10 @@ def FLOW_DispatchOp : FLOW_PureOp<"dispatch", [ }]; let assemblyFormat = [{ - $entry_point `[` $workgroup_count `]` + $entry_point `[` $workgroup_count `]` `` `(` $operands `)` attr-dict `:` - functional-type($operands, $results) + custom(type($operands), $operand_dims, + type($results), $result_dims) }]; let verifier = [{ return verifyDispatchOp(*this); }]; @@ -701,6 +709,8 @@ def FLOW_DispatchOp : FLOW_PureOp<"dispatch", [ def FLOW_TensorReshapeOp : FLOW_PureOp<"tensor.reshape", [ FLOW_StreamableOp, AllElementTypesMatch<["source", "result"]>, + AttrSizedOperandSegments, + DeclareOpInterfaceMethods, ]> { let summary = [{reshapes a tensor}]; let description = [{ @@ -708,14 +718,32 @@ def FLOW_TensorReshapeOp : FLOW_PureOp<"tensor.reshape", [ }]; let arguments = (ins - FLOW_Tensor:$source - // TODO(benvanik): FLOW_Shape:$shape when supporting dynamic shapes. + FLOW_Tensor:$source, + FLOW_ShapeDynamicDims:$source_dims, + FLOW_ShapeDynamicDims:$result_dims ); let results = (outs FLOW_Tensor:$result ); - let assemblyFormat = "$source `:` type($source) `->` type($result) attr-dict"; + let assemblyFormat = [{ + $source `:` + type($source) (`{` $source_dims^ `}`)? `->` + type($result) (`{` $result_dims^ `}`)? + attr-dict-with-keyword + }]; + + let builders = [ + OpBuilder<(ins + "Type":$result_type, "Value":$source, "ValueRange":$target_dims), + [{ + build($_builder, $_state, + result_type, + source, + Shape::buildOrFindDynamicDimsForValue($_state.location, source, $_builder), + target_dims); + }]>, + ]; let extraClassDeclaration = [{ // StreamableOpInterface: @@ -730,6 +758,8 @@ def FLOW_TensorLoadOp : FLOW_PureOp<"tensor.load", [ TypesMatchWith<"value type matches element type of target operand", "source", "result", "$_self.cast().getElementType()">, + AttrSizedOperandSegments, + DeclareOpInterfaceMethods, ]> { let summary = [{loads a value from a tensor element}]; let description = [{ @@ -738,6 +768,7 @@ def FLOW_TensorLoadOp : FLOW_PureOp<"tensor.load", [ let arguments = (ins FLOW_Tensor:$source, + FLOW_ShapeDynamicDims:$source_dims, Variadic:$indices ); let results = (outs @@ -745,9 +776,23 @@ def FLOW_TensorLoadOp : FLOW_PureOp<"tensor.load", [ ); let assemblyFormat = [{ - $source (`[` $indices^ `]`)? `:` type($source) attr-dict-with-keyword + $source (`[` $indices^ `]`)? `:` + type($source) (`{` $source_dims^ `}`)? + attr-dict-with-keyword }]; + let builders = [ + OpBuilder<(ins + "Type":$result_type, "Value":$source, CArg<"ValueRange", "{}">:$indices), + [{ + build($_builder, $_state, + result_type, + source, + Shape::buildOrFindDynamicDimsForValue($_state.location, source, $_builder), + indices); + }]>, + ]; + // TODO(benvanik): canonicalize to slice+load if dims are known. let hasFolder = 1; } @@ -757,6 +802,8 @@ def FLOW_TensorStoreOp : FLOW_PureOp<"tensor.store", [ TypesMatchWith<"value type matches element type of target operand", "target", "value", "$_self.cast().getElementType()">, + AttrSizedOperandSegments, + DeclareOpInterfaceMethods, ]> { let summary = [{stores a value into a tensor element}]; let description = [{ @@ -766,6 +813,7 @@ def FLOW_TensorStoreOp : FLOW_PureOp<"tensor.store", [ let arguments = (ins AnyTypeOf<[FLOW_PrimitiveType, AnyVector]>:$value, FLOW_Tensor:$target, + FLOW_ShapeDynamicDims:$target_dims, Variadic:$indices ); let results = (outs @@ -773,10 +821,24 @@ def FLOW_TensorStoreOp : FLOW_PureOp<"tensor.store", [ ); let assemblyFormat = [{ - $value `,` $target (`[` $indices^ `]`)? `:` type($target) + $value `,` $target (`[` $indices^ `]`)? `:` + type($target) (`{` $target_dims^ `}`)? attr-dict-with-keyword }]; + let builders = [ + OpBuilder<(ins + "Value":$value, "Value":$target, CArg<"ValueRange", "{}">:$indices), + [{ + build($_builder, $_state, + target.getType(), + value, + target, + Shape::buildOrFindDynamicDimsForValue($_state.location, target, $_builder), + indices); + }]>, + ]; + let hasFolder = 1; } @@ -785,6 +847,7 @@ def FLOW_TensorSplatOp : FLOW_PureOp<"tensor.splat", [ TypesMatchWith<"value type matches element type of result", "result", "value", "$_self.cast().getElementType()">, + DeclareOpInterfaceMethods, ]> { let summary = [{splats a value into a shaped tensor}]; let description = [{ @@ -792,14 +855,17 @@ def FLOW_TensorSplatOp : FLOW_PureOp<"tensor.splat", [ }]; let arguments = (ins - FLOW_PrimitiveType:$value - // TODO(benvanik): FLOW_Shape:$shape when supporting dynamic shapes. + FLOW_PrimitiveType:$value, + FLOW_ShapeDynamicDims:$result_dims ); let results = (outs FLOW_Tensor:$result ); - let assemblyFormat = "$value `:` type($result) attr-dict-with-keyword"; + let assemblyFormat = [{ + $value `:` type($result) (`{` $result_dims^ `}`)? + attr-dict-with-keyword + }]; let extraClassDeclaration = [{ // StreamableOpInterface: @@ -812,7 +878,8 @@ def FLOW_TensorSplatOp : FLOW_PureOp<"tensor.splat", [ def FLOW_TensorCloneOp : FLOW_PureOp<"tensor.clone", [ FLOW_StreamableOp, - SameOperandsAndResultType, + AllTypesMatch<["operand", "result"]>, + DeclareOpInterfaceMethods, ]> { let summary = [{performs a full tensor clone operation}]; let description = [{ @@ -820,13 +887,27 @@ def FLOW_TensorCloneOp : FLOW_PureOp<"tensor.clone", [ }]; let arguments = (ins - FLOW_Tensor:$operand + FLOW_Tensor:$operand, + FLOW_ShapeDynamicDims:$operand_dims ); let results = (outs FLOW_Tensor:$result ); - let assemblyFormat = "$operand `:` type($result) attr-dict"; + let assemblyFormat = [{ + $operand `:` type($result) (`{` $operand_dims^ `}`)? + attr-dict-with-keyword + }]; + + let builders = [ + OpBuilder<(ins "Value":$operand), + [{ + build($_builder, $_state, + operand.getType(), + operand, + Shape::buildOrFindDynamicDimsForValue($_state.location, operand, $_builder)); + }]>, + ]; let extraClassDeclaration = [{ // StreamableOpInterface: @@ -841,7 +922,8 @@ def FLOW_TensorSliceOp : FLOW_PureOp<"tensor.slice", [ FLOW_StreamableOp, AllRanksMatch<["source", "result"]>, AllElementTypesMatch<["source", "result"]>, - SameVariadicOperandSize, + AttrSizedOperandSegments, + DeclareOpInterfaceMethods, ]> { let summary = [{slices out a subregion of a tensor}]; let description = [{ @@ -850,17 +932,20 @@ def FLOW_TensorSliceOp : FLOW_PureOp<"tensor.slice", [ let arguments = (ins FLOW_Tensor:$source, + FLOW_ShapeDynamicDims:$source_dims, Variadic:$start_indices, - Variadic:$lengths - // TODO(benvanik): strides. + Variadic:$lengths, + FLOW_ShapeDynamicDims:$result_dims ); let results = (outs FLOW_Tensor:$result ); let assemblyFormat = [{ - $source `[` $start_indices `for` $lengths `]` `:` type($source) `->` - type($result) attr-dict + $source `[` $start_indices `for` $lengths `]` `:` + type($source) (`{` $source_dims^ `}`)? `->` + type($result) (`{` $result_dims^ `}`)? + attr-dict-with-keyword }]; let extraClassDeclaration = [{ @@ -877,6 +962,8 @@ def FLOW_TensorUpdateOp : FLOW_PureOp<"tensor.update", [ AllRanksMatch<["update", "target", "result"]>, AllTypesMatch<["target", "result"]>, AllElementTypesMatch<["update", "target", "result"]>, + AttrSizedOperandSegments, + DeclareOpInterfaceMethods, ]> { let summary = [{updates a tensor with the contents of another tensor}]; let description = [{ @@ -885,25 +972,37 @@ def FLOW_TensorUpdateOp : FLOW_PureOp<"tensor.update", [ }]; let arguments = (ins - FLOW_Tensor:$update, FLOW_Tensor:$target, - Variadic:$start_indices + FLOW_ShapeDynamicDims:$target_dims, + Variadic:$start_indices, + FLOW_Tensor:$update, + FLOW_ShapeDynamicDims:$update_dims ); let results = (outs FLOW_Tensor:$result ); let assemblyFormat = [{ - $update `,` $target `[` $start_indices `]` `:` type($update) `->` - type($result) attr-dict + $update `,` $target `[` $start_indices `]` `:` + type($update) (`{` $update_dims^ `}`)? `->` + type($result) (`{` $target_dims^ `}`)? + attr-dict-with-keyword }]; + let builders = [ + OpBuilder<(ins + "Value":$target, + "ValueRange":$start_indices, + "Value":$update)>, + ]; + let extraClassDeclaration = [{ // StreamableOpInterface: bool isTransfer() { return true; } }]; - // TODO(benvanik): canonicalize contiguous updates/across slices. + let verifier = [{ return verifyTensorUpdateOp(*this); }]; + let hasCanonicalizer = 1; let hasFolder = 1; } @@ -932,6 +1031,9 @@ def FLOW_TensorTraceOp : FLOW_Op<"tensor.trace", []> { // TODO(benvanik): replace with real segmented stream ops. def FLOW_ExStreamFragmentOp : FLOW_PureOp<"ex.stream.fragment", [ IsolatedFromAbove, + AttrSizedOperandSegments, + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods, ]> { let summary = [{experimental op for defining formed stream regions}]; let description = [{ @@ -940,7 +1042,9 @@ def FLOW_ExStreamFragmentOp : FLOW_PureOp<"ex.stream.fragment", [ }]; let arguments = (ins - Variadic:$args + Variadic:$operands, + FLOW_ShapeDynamicDims:$operand_dims, + FLOW_ShapeDynamicDims:$result_dims ); let results = (outs Variadic:$results @@ -948,12 +1052,27 @@ def FLOW_ExStreamFragmentOp : FLOW_PureOp<"ex.stream.fragment", [ let regions = (region AnyRegion:$body); + let assemblyFormat = [{ + `(` $operands `)` `:` + custom(type($operands), $operand_dims, + type($results), $result_dims) + attr-dict-with-keyword + `=` `\n` ` ` ` ` ` ` + custom(ref(type($operands)), + ref(type($results)), + $body) + }]; + let skipDefaultBuilders = 1; let builders = [ - OpBuilder<(ins "ArrayRef":$resultTypes, "ValueRange":$args, + OpBuilder<(ins + "TypeRange":$resultTypes, "ValueRange":$resultDims, + "ValueRange":$operands, "ValueRange":$operandDims, CArg<"ArrayRef", "{}">:$attributes)>, ]; + let verifier = [{ return verifyExStreamFragmentOp(*this); }]; + let hasCanonicalizer = 1; } diff --git a/iree/compiler/Dialect/Flow/IR/test/dispatch_ops.mlir b/iree/compiler/Dialect/Flow/IR/test/dispatch_ops.mlir index 2325c63bfac5..8861704fca00 100644 --- a/iree/compiler/Dialect/Flow/IR/test/dispatch_ops.mlir +++ b/iree/compiler/Dialect/Flow/IR/test/dispatch_ops.mlir @@ -4,7 +4,7 @@ flow.executable @ex0 { module { - func @dispatch_fn(%arg0 : tensor<4xf32>) -> tensor<4xf32> { + func @dispatch_fn(%cst : index, %arg0 : tensor<4xf32>) -> tensor<4xf32> { return %arg0 : tensor<4xf32> } } @@ -15,7 +15,7 @@ flow.executable @ex0 { func @dispatch(%arg0 : tensor<4xf32>) -> tensor<4xf32> { // CHECK: %[[CST:.+]] = constant %cst = constant 4 : index - // CHECK: %0 = flow.dispatch @ex0::@dispatch_fn[%[[CST]]] (%arg0) : (tensor<4xf32>) -> tensor<4xf32> - %0 = flow.dispatch @ex0::@dispatch_fn[%cst] (%arg0) : (tensor<4xf32>) -> tensor<4xf32> + // CHECK: %0 = flow.dispatch @ex0::@dispatch_fn[%[[CST]]](%[[CST]], %arg0) : (index, tensor<4xf32>) -> (tensor<4xf32>) + %0 = flow.dispatch @ex0::@dispatch_fn[%cst](%cst, %arg0) : (index, tensor<4xf32>) -> (tensor<4xf32>) return %0 : tensor<4xf32> } diff --git a/iree/compiler/Dialect/Flow/IR/test/dispatch_region_folding.mlir b/iree/compiler/Dialect/Flow/IR/test/dispatch_region_folding.mlir index 8ab0c0f869b5..fb678b8894ed 100644 --- a/iree/compiler/Dialect/Flow/IR/test/dispatch_region_folding.mlir +++ b/iree/compiler/Dialect/Flow/IR/test/dispatch_region_folding.mlir @@ -5,7 +5,7 @@ func @dceOperandsAndResults(%arg0 : tensor) -> (tensor) { // CHECK: %[[WORKLOAD:.+]] = constant 5 %workload = constant 5 : index // CHECK: %[[R0:.+]] = flow.dispatch.region[%[[WORKLOAD]] : index] - // CHECK-SAME: (%[[CA1:.+]] = %arg0 : tensor) -> tensor + // CHECK-SAME: (%[[CA1:.+]] = %arg0 : tensor) -> (tensor) // CHECK: %[[DR0:.+]] = addf %[[CA1]], %[[CA1]] // CHECK: flow.return %[[DR0]] : tensor %ret0, %ret1 = flow.dispatch.region[%workload : index]( diff --git a/iree/compiler/Dialect/Flow/IR/test/dispatch_regions.mlir b/iree/compiler/Dialect/Flow/IR/test/dispatch_regions.mlir index 8e7171ebf9ca..d06d20c4ddfb 100644 --- a/iree/compiler/Dialect/Flow/IR/test/dispatch_regions.mlir +++ b/iree/compiler/Dialect/Flow/IR/test/dispatch_regions.mlir @@ -37,7 +37,7 @@ func @multipleArgs(%arg0 : tensor, %arg1 : tensor) { // CHECK-LABEL: @singleResult func @singleResult(%arg0 : tensor) -> tensor { // CHECK-NEXT: %[[WORKLOAD:.+]] = "some.shape" - // CHECK-NEXT: %1 = flow.dispatch.region[%[[WORKLOAD]] : index](%arg1 = %arg0 : tensor) -> tensor { + // CHECK-NEXT: %1 = flow.dispatch.region[%[[WORKLOAD]] : index](%arg1 = %arg0 : tensor) -> (tensor) { // CHECK-NEXT: flow.return %arg1 : tensor // CHECK-NEXT: } %workload = "some.shape"(%arg0) : (tensor) -> index diff --git a/iree/compiler/Dialect/Flow/IR/test/dispatch_workgroups.mlir b/iree/compiler/Dialect/Flow/IR/test/dispatch_workgroups.mlir index a0ab78448d05..3f42360a34e3 100644 --- a/iree/compiler/Dialect/Flow/IR/test/dispatch_workgroups.mlir +++ b/iree/compiler/Dialect/Flow/IR/test/dispatch_workgroups.mlir @@ -6,19 +6,20 @@ func @complexWorkgroupsUsage( %arg0 : tensor, // CHECK-SAME: %[[ARG1:.+]]: index %arg1 : index) -> tensor<4x?xf32> { + %c128 = constant 128 : index // CHECK-DAG: %[[WORKGROUP_COUNT_X:.+]] = constant 100 %x = constant 100 : index // CHECK-DAG: %[[WORKGROUP_COUNT_Y:.+]] = constant 50 %y = constant 50 : index // CHECK: %[[OUTER_RET0:.+]] = flow.dispatch.workgroups[ // CHECK-SAME: %[[WORKGROUP_COUNT_X]], %[[WORKGROUP_COUNT_Y]] - // CHECK-SAME: ] (%[[ARG0]], %[[ARG1]]) - // CHECK-SAME: : (tensor, index) -> tensor<4x?xf32> = ( - %0 = flow.dispatch.workgroups[%x, %y](%arg0, %arg1) : (tensor, index) -> (tensor<4x?xf32>) = - // CHECK-SAME: %[[INNER_ARG0:.+]] : !flow.dispatch.input - // CHECK-SAME: %[[INNER_ARG1:.+]] : index - // CHECK-SAME: %[[INNER_RET0:.+]] : !flow.dispatch.output<4x?xf32> - (%arg0_capture : !flow.dispatch.input, %arg1_capture : index, %ret0 : !flow.dispatch.output<4x?xf32>) { + // CHECK-SAME: ](%[[ARG0]], %[[ARG1]]) + // CHECK-SAME: : (tensor{%c128}, index) -> (tensor<4x?xf32>{%c128}) = + %0 = flow.dispatch.workgroups[%x, %y](%arg0, %arg1) : (tensor{%c128}, index) -> tensor<4x?xf32>{%c128} = + // CHECK-NEXT: (%[[INNER_ARG0:.+]]: !flow.dispatch.input + // CHECK-SAME: %[[INNER_ARG1:.+]]: index + // CHECK-SAME: %[[INNER_RET0:.+]]: !flow.dispatch.output<4x?xf32>) { + (%arg0_capture: !flow.dispatch.input, %arg1_capture: index, %ret0: !flow.dispatch.output<4x?xf32>) { // Query symbolic workgroup info: diff --git a/iree/compiler/Dialect/Flow/IR/test/dispatch_workgroups_folding.mlir b/iree/compiler/Dialect/Flow/IR/test/dispatch_workgroups_folding.mlir index 697ea9eb64b2..906c5f672dda 100644 --- a/iree/compiler/Dialect/Flow/IR/test/dispatch_workgroups_folding.mlir +++ b/iree/compiler/Dialect/Flow/IR/test/dispatch_workgroups_folding.mlir @@ -2,14 +2,15 @@ // CHECK-LABEL: @workgroupStaticShapeDims func @workgroupStaticShapeDims(%arg0 : tensor) -> tensor<4x?xf32> { + %c128 = constant 128 : index %x = constant 100 : index %y = constant 50 : index // CHECK: flow.dispatch.workgroups - %0 = flow.dispatch.workgroups[%x, %y](%arg0) : (tensor) -> (tensor<4x?xf32>) = ( - // CHECK-SAME: = (%[[ARG0:.+]] : !flow.dispatch.input - %arg0_capture : !flow.dispatch.input, - // CHECK-SAME: %[[RET0:.+]] : !flow.dispatch.output<4x?xf32>) - %ret0 : !flow.dispatch.output<4x?xf32> + %0 = flow.dispatch.workgroups[%x, %y](%arg0) : (tensor{%c128}) -> tensor<4x?xf32>{%c128} = ( + // CHECK-NEXT: (%[[ARG0:.+]]: !flow.dispatch.input, + %arg0_capture: !flow.dispatch.input, + // CHECK-SAME: %[[RET0:.+]]: !flow.dispatch.output<4x?xf32>) + %ret0: !flow.dispatch.output<4x?xf32> ) { // CHECK: %[[DIM_4:.+]] = constant 4 : index @@ -38,12 +39,13 @@ func @workgroupStaticShapeDims(%arg0 : tensor) -> tensor<4x?xf32> { // CHECK-LABEL: @workgroupRankFolding func @workgroupRankFolding(%arg0 : tensor) -> tensor<4x?xf32> { + %c128 = constant 128 : index %x = constant 100 : index %y = constant 50 : index // CHECK: flow.dispatch.workgroups - %0 = flow.dispatch.workgroups[%x, %y](%arg0) : (tensor) -> (tensor<4x?xf32>) = ( - %arg0_capture : !flow.dispatch.input, - %ret0 : !flow.dispatch.output<4x?xf32> + %0 = flow.dispatch.workgroups[%x, %y](%arg0) : (tensor{%c128}) -> tensor<4x?xf32>{%c128} = ( + %arg0_capture: !flow.dispatch.input, + %ret0: !flow.dispatch.output<4x?xf32> ) { // CHECK: %[[RANK:.+]] = constant 2 : index %workgroup_rank = flow.dispatch.workgroup.rank : index diff --git a/iree/compiler/Dialect/Flow/IR/test/stream_folding.mlir b/iree/compiler/Dialect/Flow/IR/test/stream_folding.mlir index 028a583ab78b..e68847663585 100644 --- a/iree/compiler/Dialect/Flow/IR/test/stream_folding.mlir +++ b/iree/compiler/Dialect/Flow/IR/test/stream_folding.mlir @@ -2,70 +2,55 @@ // RUN: iree-opt -split-input-file -canonicalize %s | iree-opt -split-input-file | IreeFileCheck %s -flow.executable @dispatch_0 { - flow.dispatch.entry @rgn_dispatch_0 - module { - func @rgn_dispatch_0(%arg0: tensor<4xf32>) -> tensor<4xf32> { - %0 = mhlo.multiply %arg0, %arg0 : tensor<4xf32> - return %0 : tensor<4xf32> - } - } -} - -// CHECK-LABEL: func @fragmentDedupCaptures -// CHECK-SAME: %[[A0:.+]]: tensor<4xf32> -func @fragmentDedupCaptures(%arg0 : tensor<4xf32>) -> (tensor<4xf32>, tensor<4xf32>) { - // CHECK: %[[CST:.+]] = constant 4 +// CHECK-LABEL: func @inlineConstant +func @inlineConstant() -> index { %cst = constant 4 : index - // Should dedup %cst in arg list. - // CHECK: flow.ex.stream.fragment(%arg1 = %[[CST]] : index, %arg2 = %[[A0]] : tensor<4xf32>) - %0:2 = flow.ex.stream.fragment(%arg1 = %cst : index, %arg2 = %cst : index, %arg3 = %arg0 : tensor<4xf32>) -> (tensor<4xf32>, tensor<4xf32>) { - // Both referreants of the constant should use the deduped arg. - // CHECK: flow.dispatch @dispatch_0::@rgn_dispatch_0[%arg1] - // CHECK: flow.dispatch @dispatch_0::@rgn_dispatch_0[%arg1] - %1 = flow.dispatch @dispatch_0::@rgn_dispatch_0[%arg1] (%arg3) : (tensor<4xf32>) -> tensor<4xf32> - %2 = flow.dispatch @dispatch_0::@rgn_dispatch_0[%arg2] (%1) : (tensor<4xf32>) -> tensor<4xf32> - flow.return %2, %2 : tensor<4xf32>, tensor<4xf32> + // CHECK: flow.ex.stream.fragment() + %0 = flow.ex.stream.fragment(%cst) : (index) -> index = + (%arg0: index) -> index { + // CHECK: %[[C:.+]] = constant 4 : index + // CHECK-NEXT: return %[[C]] + flow.return %arg0 : index } - return %0#0, %0#1 : tensor<4xf32>, tensor<4xf32> + return %0 : index } // ----- + // CHECK-LABEL: func @removeUnusedCapture -func @removeUnusedCapture() -> (index) { - // CHECK: %[[CST:.+]] = constant 4 - %cst = constant 4 : index +// CHECK-SAME: (%[[ARG:.+]]: index) +func @removeUnusedCapture(%arg: index) -> index { %unused = constant 5 : index - // CHECK: flow.ex.stream.fragment(%arg0 = %[[CST]] : index) - %0 = flow.ex.stream.fragment(%arg0 = %cst : index, %arg1 = %unused : index) -> (index) { + // CHECK: flow.ex.stream.fragment(%[[ARG]]) + %0 = flow.ex.stream.fragment(%arg, %unused) : (index, index) -> index = + (%arg0: index, %arg1: index) -> index { flow.return %arg0 : index } return %0 : index } // ----- + // CHECK-LABEL: func @removeUnusedDupCapture -func @removeUnusedDupCapture() -> (index) { - // CHECK: %[[CST:.+]] = constant 4 - %cst = constant 4 : index - // CHECK: flow.ex.stream.fragment(%arg0 = %[[CST]] : index) - %0 = flow.ex.stream.fragment(%arg0 = %cst : index, %arg1 = %cst : index) -> (index) { +// CHECK-SAME: (%[[ARG:.+]]: index) +func @removeUnusedDupCapture(%arg: index) -> index { + // CHECK: flow.ex.stream.fragment(%[[ARG]]) + %0 = flow.ex.stream.fragment(%arg, %arg) : (index, index) -> index = + (%arg0: index, %arg1: index) -> index { flow.return %arg1 : index } return %0 : index } // ----- + // CHECK-LABEL: func @removeUnusedResult -func @removeUnusedResult() -> (index) { - // CHECK: %[[CST:.+]] = constant 4 - %cst = constant 4 : index - // Note that the unused result should also cascade to elide the newly - // unused operand. - // CHECK: flow.ex.stream.fragment(%arg0 = %[[CST]] : index) - // CHECK-SAME: -> index - %0:2 = flow.ex.stream.fragment(%arg0 = %cst : index, %arg1 = %cst : index) -> (index, index) { +// CHECK-SAME: (%[[ARG0:.+]]: index, %[[ARG1:.+]]: index) +func @removeUnusedResult(%arg0: index, %arg1: index) -> index { + // CHECK: flow.ex.stream.fragment(%[[ARG1]]) + %0:2 = flow.ex.stream.fragment(%arg0, %arg1) : (index, index) -> (index, index) = + (%arg0: index, %arg1: index) -> (index, index) { flow.return %arg1, %arg0 : index, index } - return %0 : index + return %0#0 : index } diff --git a/iree/compiler/Dialect/Flow/IR/test/stream_ops.mlir b/iree/compiler/Dialect/Flow/IR/test/stream_ops.mlir index 457ea939c208..1dab1a9f4331 100644 --- a/iree/compiler/Dialect/Flow/IR/test/stream_ops.mlir +++ b/iree/compiler/Dialect/Flow/IR/test/stream_ops.mlir @@ -16,8 +16,10 @@ flow.executable @dispatch_0 { func @fragment(%arg0 : tensor<4xf32>) -> (tensor<4xf32>, tensor<4xf32>) { // CHECK: %[[WORKLOAD:.+]] = constant %cst = constant 4 : index - // CHECK: %0:2 = flow.ex.stream.fragment(%arg1 = %[[WORKLOAD]] : index, %arg2 = %arg0 : tensor<4xf32>) -> (tensor<4xf32>, tensor<4xf32>) { - %0:2 = flow.ex.stream.fragment(%arg1 = %cst : index, %arg2 = %arg0 : tensor<4xf32>) -> (tensor<4xf32>, tensor<4xf32>) { + // CHECK: %0:2 = flow.ex.stream.fragment(%[[WORKLOAD]], %arg0) : (index, tensor<4xf32>) -> (tensor<4xf32>, tensor<4xf32>) = + // CHECK-NEXT: (%arg1: index, %arg2: tensor<4xf32>) -> (tensor<4xf32>, tensor<4xf32>) { + %0:2 = flow.ex.stream.fragment(%cst, %arg0) : (index, tensor<4xf32>) -> (tensor<4xf32>, tensor<4xf32>) = + (%arg1 : index, %arg2 : tensor<4xf32>) -> (tensor<4xf32>, tensor<4xf32>) { // CHECK-NEXT: flow.dispatch %1 = flow.dispatch @dispatch_0::@rgn_dispatch_0[%arg1] (%arg2) : (tensor<4xf32>) -> tensor<4xf32> // CHECK-NEXT: flow.return diff --git a/iree/compiler/Dialect/Flow/IR/test/tensor_folding.mlir b/iree/compiler/Dialect/Flow/IR/test/tensor_folding.mlir index 502c4184d4aa..570b255855d7 100644 --- a/iree/compiler/Dialect/Flow/IR/test/tensor_folding.mlir +++ b/iree/compiler/Dialect/Flow/IR/test/tensor_folding.mlir @@ -9,6 +9,8 @@ func @reshapeNoOp(%arg0 : tensor<4x4xf32>) -> tensor<4x4xf32> { return %0 : tensor<4x4xf32> } +// ----- + // CHECK-LABEL: @reshapeNoOpScalar func @reshapeNoOpScalar(%arg0 : tensor) -> tensor { // CHECK-NEXT: return %arg0 : tensor @@ -16,15 +18,6 @@ func @reshapeNoOpScalar(%arg0 : tensor) -> tensor { return %0 : tensor } -// CHECK-LABEL: @reshapeTransitive -func @reshapeTransitive(%arg0 : tensor<4x4xf32>) -> tensor<8x2xf32> { - %0 = flow.tensor.reshape %arg0 : tensor<4x4xf32> -> tensor<2x8xf32> - // CHECK-NEXT: %[[T:.+]] = flow.tensor.reshape %arg0 : tensor<4x4xf32> -> tensor<8x2xf32> - %1 = flow.tensor.reshape %0 : tensor<2x8xf32> -> tensor<8x2xf32> - // CHECK-NEXT: return %[[T]] : tensor<8x2xf32> - return %1 : tensor<8x2xf32> -} - // ----- // CHECK-LABEL: @loadConst @@ -38,6 +31,8 @@ func @loadConst() -> i32 { return %2 : i32 } +// ----- + // CHECK-LABEL: @loadConstScalar func @loadConstScalar() -> i32 { %0 = constant dense<4> : tensor @@ -63,6 +58,8 @@ func @storeConst() -> tensor<2x2xi32> { return %1 : tensor<2x2xi32> } +// ----- + // CHECK-LABEL: @storeConstScalar func @storeConstScalar() -> tensor { %0 = constant dense<0> : tensor @@ -84,6 +81,8 @@ func @splatConst() -> tensor<4xi32> { return %1 : tensor<4xi32> } +// ----- + // CHECK-LABEL: @splatConstScalar func @splatConstScalar() -> tensor { %0 = constant 4 : i32 @@ -104,6 +103,8 @@ func @cloneConst() -> tensor<4xi32> { return %1 : tensor<4xi32> } +// ----- + // CHECK-LABEL: @cloneDynamic func @cloneDynamic(%arg0 : tensor<4xi32>) -> tensor<4xi32> { %0 = flow.tensor.clone %arg0 : tensor<4xi32> @@ -122,6 +123,8 @@ func @sliceConst0D() -> tensor { return %1 : tensor } +// ----- + // CHECK-LABEL: @sliceConst1D func @sliceConst1D() -> tensor<1xi32> { %0 = constant dense<0> : tensor<1xi32> @@ -133,6 +136,8 @@ func @sliceConst1D() -> tensor<1xi32> { return %1 : tensor<1xi32> } +// ----- + // CHECK-LABEL: @sliceConst1DZeroLength func @sliceConst1DZeroLength() -> tensor<0xi32> { %0 = constant dense<0> : tensor<1xi32> @@ -143,6 +148,8 @@ func @sliceConst1DZeroLength() -> tensor<0xi32> { return %1 : tensor<0xi32> } +// ----- + // CHECK-LABEL: @sliceConst2D func @sliceConst2D() -> tensor<1x2xi32> { %0 = constant dense<[[0, 1, 2], [3, 4, 5]]> : tensor<2x3xi32> @@ -157,6 +164,8 @@ func @sliceConst2D() -> tensor<1x2xi32> { return %1 : tensor<1x2xi32> } +// ----- + // CHECK-LABEL: @sliceConst2DZeroLength1 func @sliceConst2DZeroLength1() -> tensor<1x0xi32> { %0 = constant dense<[[0, 1, 2], [3, 4, 5]]> : tensor<2x3xi32> @@ -168,6 +177,8 @@ func @sliceConst2DZeroLength1() -> tensor<1x0xi32> { return %1 : tensor<1x0xi32> } +// ----- + // CHECK-LABEL: @sliceConst2DZeroLength01 func @sliceConst2DZeroLength01() -> tensor<0x0xi32> { %0 = constant dense<[[0, 1, 2], [3, 4, 5]]> : tensor<2x3xi32> @@ -178,6 +189,8 @@ func @sliceConst2DZeroLength01() -> tensor<0x0xi32> { return %1 : tensor<0x0xi32> } +// ----- + // CHECK-LABEL: @sliceConst3D func @sliceConst3D() -> tensor<1x2x3xi32> { %0 = constant dense<[[[0, 1, 2], [3, 4, 5], [6, 7, 8]], [[9, 10, 11], [12, 13, 14], [15, 16, 17]]]> : tensor<2x3x3xi32> @@ -205,6 +218,8 @@ func @updateConst0D() -> tensor { return %2 : tensor } +// ----- + // CHECK-LABEL: @updateConst1D func @updateConst1D() -> tensor<1xi32> { %0 = constant dense<0> : tensor<1xi32> @@ -216,6 +231,8 @@ func @updateConst1D() -> tensor<1xi32> { return %2 : tensor<1xi32> } +// ----- + // CHECK-LABEL: @updateConst1DUpdateZeroSize func @updateConst1DUpdateZeroSize() -> tensor<1xi32> { %0 = constant dense<> : tensor<0xi32> @@ -227,6 +244,8 @@ func @updateConst1DUpdateZeroSize() -> tensor<1xi32> { return %2 : tensor<1xi32> } +// ----- + // CHECK-LABEL: @updateConst2DUpdate1x1 func @updateConst2DUpdate1x1() -> tensor<3x4xi32> { %0 = constant dense<[[12]]> : tensor<1x1xi32> @@ -240,6 +259,8 @@ func @updateConst2DUpdate1x1() -> tensor<3x4xi32> { return %2 : tensor<3x4xi32> } +// ----- + // CHECK-LABEL: @updateConst2DUpdate2x2 func @updateConst2DUpdate2x2() -> tensor<3x4xi32> { %0 = constant dense<[[12, 13], [14, 15]]> : tensor<2x2xi32> @@ -253,6 +274,8 @@ func @updateConst2DUpdate2x2() -> tensor<3x4xi32> { return %2 : tensor<3x4xi32> } +// ----- + // CHECK-LABEL: @updateConst3DUpdate1x2x3 func @updateConst3DUpdate1x2x3() -> tensor<2x3x3xi32> { %0 = constant dense<[[[18, 19, 20], [21, 22, 23]]]> : tensor<1x2x3xi32> @@ -268,6 +291,8 @@ func @updateConst3DUpdate1x2x3() -> tensor<2x3x3xi32> { return %2 : tensor<2x3x3xi32> } +// ----- + // CHECK-LABEL: @updateConst3DUpdate2x3x2 func @updateConst3DUpdate2x3x2() -> tensor<2x3x3xi32> { %0 = constant dense<[[[18, 19], [20, 21], [22, 23]], [[24, 25], [26, 27], [28, 29]]]> : tensor<2x3x2xi32> @@ -283,6 +308,8 @@ func @updateConst3DUpdate2x3x2() -> tensor<2x3x3xi32> { return %2 : tensor<2x3x3xi32> } +// ----- + // CHECK-LABEL: @updateReplace func @updateReplace(%arg0 : tensor<4xi32>, %arg1 : tensor<4xi32>) -> tensor<4xi32> { %c0 = constant 0 : index @@ -291,6 +318,8 @@ func @updateReplace(%arg0 : tensor<4xi32>, %arg1 : tensor<4xi32>) -> tensor<4xi3 return %0 : tensor<4xi32> } +// ----- + // CHECK-LABEL: @propogateStaticShapeOfTarget func @propogateStaticShapeOfTarget(%arg0 : tensor, %arg1 : f32) -> tensor { %c21 = constant 21 : index @@ -305,11 +334,13 @@ func @propogateStaticShapeOfTarget(%arg0 : tensor, %arg1 : f32) -> tens } : tensor // CHECK: %[[UPDATED:.+]] = flow.tensor.update %{{.+}}, %[[TARGET]] // CHECK: %[[RESULT:.+]] = tensor.cast %[[UPDATED]] : tensor<21x42xf32> to tensor - %1 = flow.tensor.update %arg0, %0[%c2, %c4] : tensor -> tensor + %1 = flow.tensor.update %arg0, %0[%c2, %c4] : tensor{%c21, %c42} -> tensor{%c21, %c42} // CHECK: return %[[RESULT]] return %1 : tensor } +// ----- + // CHECK-LABEL: @propogateStaticShapeOfUpdate func @propogateStaticShapeOfUpdate(%arg0 : tensor, %arg1 : f32) -> tensor { %c21 = constant 21 : index @@ -323,7 +354,7 @@ func @propogateStaticShapeOfUpdate(%arg0 : tensor, %arg1 : f32) -> tens tensor.yield %arg1 : f32 } : tensor // CHECK: %[[RESULT:.+]] = flow.tensor.update %[[UPDATE]] - %1 = flow.tensor.update %0, %arg0[%c2, %c4] : tensor -> tensor + %1 = flow.tensor.update %0, %arg0[%c2, %c4] : tensor{%c21, %c42} -> tensor{%c21, %c42} // CHECK: return %[[RESULT]] return %1 : tensor } diff --git a/iree/compiler/Dialect/Flow/IR/test/tensor_ops.mlir b/iree/compiler/Dialect/Flow/IR/test/tensor_ops.mlir index 1e45299ff38c..84fef0802390 100644 --- a/iree/compiler/Dialect/Flow/IR/test/tensor_ops.mlir +++ b/iree/compiler/Dialect/Flow/IR/test/tensor_ops.mlir @@ -16,6 +16,15 @@ func @tensorReshapeScalar(%arg0 : tensor) -> tensor { return %0 : tensor } +// CHECK-LABEL: @tensorReshapeDynamic +func @tensorReshapeDynamic(%arg0 : tensor) -> tensor { + %c4 = constant 4 : index + %c8 = constant 8 : index + // CHECK: %0 = flow.tensor.reshape %arg0 : tensor{%c4} -> tensor{%c8} + %0 = flow.tensor.reshape %arg0 : tensor{%c4} -> tensor{%c8} + return %0 : tensor +} + // ----- // CHECK-LABEL: @tensorLoad @@ -32,6 +41,14 @@ func @tensorLoadScalar(%arg0 : tensor) -> f32 { return %0 : f32 } +// CHECK-LABEL: @tensorLoadDynamic +func @tensorLoadDynamic(%arg0 : tensor, %arg1 : index, %arg2 : index) -> f32 { + %c4 = constant 4 : index + // CHECK: %0 = flow.tensor.load %arg0[%arg1, %arg2] : tensor{%c4} + %0 = flow.tensor.load %arg0[%arg1, %arg2] : tensor{%c4} + return %0 : f32 +} + // ----- // CHECK-LABEL: @tensorStore @@ -48,6 +65,14 @@ func @tensorStoreScalar(%arg0 : f32, %arg1 : tensor) -> tensor { return %0 : tensor } +// CHECK-LABEL: @tensorStoreDynamic +func @tensorStoreDynamic(%arg0 : tensor, %arg1 : index, %arg2 : index, %arg3 : f32) -> tensor { + %c4 = constant 4 : index + // CHECK: %0 = flow.tensor.store %arg3, %arg0[%arg1, %arg2] : tensor{%c4} + %0 = flow.tensor.store %arg3, %arg0[%arg1, %arg2] : tensor{%c4} + return %0 : tensor +} + // ----- // CHECK-LABEL: @tensorSplat @@ -64,6 +89,14 @@ func @tensorSplatScalar(%arg0 : f32) -> tensor { return %0 : tensor } +// CHECK-LABEL: @tensorSplatDynamic +func @tensorSplatDynamic(%arg0 : f32) -> tensor { + %c4 = constant 4 : index + // CHECK: %0 = flow.tensor.splat %arg0 : tensor{%c4} + %0 = flow.tensor.splat %arg0 : tensor{%c4} + return %0 : tensor +} + // ----- // CHECK-LABEL: @tensorClone @@ -80,6 +113,14 @@ func @tensorCloneScalar(%arg0 : tensor) -> tensor { return %0 : tensor } +// CHECK-LABEL: @tensorCloneDynamic +func @tensorCloneDynamic(%arg0 : tensor) -> tensor { + %c4 = constant 4 : index + // CHECK: %0 = flow.tensor.clone %arg0 : tensor{%c4} + %0 = flow.tensor.clone %arg0 : tensor{%c4} + return %0 : tensor +} + // ----- // CHECK-LABEL: @tensorSlice @@ -89,6 +130,15 @@ func @tensorSlice(%arg0 : tensor<4x4xf32>, %arg1 : index, %arg2 : index) -> tens return %0 : tensor<2x2xf32> } +// CHECK-LABEL: @tensorSliceDynamic +func @tensorSliceDynamic(%arg0 : tensor, %arg1 : index, %arg2 : index) -> tensor { + %c2 = constant 2 : index + %c4 = constant 4 : index + // CHECK: %0 = flow.tensor.slice %arg0[%arg1, %arg2 for %arg2, %arg1] : tensor{%c4} -> tensor{%c2} + %0 = flow.tensor.slice %arg0[%arg1, %arg2 for %arg2, %arg1] : tensor{%c4} -> tensor{%c2} + return %0 : tensor +} + // ----- // CHECK-LABEL: @tensorUpdate @@ -97,3 +147,13 @@ func @tensorUpdate(%arg0 : tensor<2x2xf32>, %arg1 : tensor<4x4xf32>, %arg2 : ind %0 = flow.tensor.update %arg0, %arg1[%arg2, %arg3] : tensor<2x2xf32> -> tensor<4x4xf32> return %0 : tensor<4x4xf32> } + +// CHECK-LABEL: @tensorUpdateDynamic +func @tensorUpdateDynamic(%arg0 : tensor, %arg1 : tensor, %arg2 : index, %arg3 : index) -> tensor { + %c1 = constant 1 : index + %c2 = constant 2 : index + %c3 = constant 3 : index + // CHECK: %0 = flow.tensor.update %arg0, %arg1[%arg2, %arg3] : tensor{%c1, %c2} -> tensor{%c3} + %0 = flow.tensor.update %arg0, %arg1[%arg2, %arg3] : tensor{%c1, %c2} -> tensor{%c3} + return %0 : tensor +} diff --git a/iree/compiler/Dialect/Flow/IR/test/types.mlir b/iree/compiler/Dialect/Flow/IR/test/types.mlir index 110a09c1fd77..dcae0a2362e1 100644 --- a/iree/compiler/Dialect/Flow/IR/test/types.mlir +++ b/iree/compiler/Dialect/Flow/IR/test/types.mlir @@ -3,25 +3,25 @@ // CHECK-LABEL: @dispatchTypes func @dispatchTypes( // CHECK-SAME: %arg0: !flow.dispatch.input - %arg0 : !flow.dispatch.input, + %arg0: !flow.dispatch.input, // CHECK-SAME: %arg1: !flow.dispatch.input<4x4xf32> - %arg1 : !flow.dispatch.input<4x4xf32>, + %arg1: !flow.dispatch.input<4x4xf32>, // CHECK-SAME: %arg2: !flow.dispatch.input<1x2x3x4x5x6xf32> - %arg2 : !flow.dispatch.input<1x2x3x4x5x6xf32>, + %arg2: !flow.dispatch.input<1x2x3x4x5x6xf32>, // CHECK-SAME: %arg3: !flow.dispatch.input - %arg3 : !flow.dispatch.input, + %arg3: !flow.dispatch.input, // CHECK-SAME: %arg4: !flow.dispatch.input<1x?x3xf32> - %arg4 : !flow.dispatch.input<1x?x3xf32>, + %arg4: !flow.dispatch.input<1x?x3xf32>, // CHECK-SAME: %arg5: !flow.dispatch.output - %arg5 : !flow.dispatch.output, + %arg5: !flow.dispatch.output, // CHECK-SAME: %arg6: !flow.dispatch.output<4x4xf32> - %arg6 : !flow.dispatch.output<4x4xf32>, + %arg6: !flow.dispatch.output<4x4xf32>, // CHECK-SAME: %arg7: !flow.dispatch.output<1x2x3x4x5x6xf32> - %arg7 : !flow.dispatch.output<1x2x3x4x5x6xf32>, + %arg7: !flow.dispatch.output<1x2x3x4x5x6xf32>, // CHECK-SAME: %arg8: !flow.dispatch.output - %arg8 : !flow.dispatch.output, + %arg8: !flow.dispatch.output, // CHECK-SAME: %arg9: !flow.dispatch.output<1x?x3xf32> - %arg9 : !flow.dispatch.output<1x?x3xf32> + %arg9: !flow.dispatch.output<1x?x3xf32> ) { return } diff --git a/iree/compiler/Dialect/Flow/Transforms/CreateBenchmarkFuncs.cpp b/iree/compiler/Dialect/Flow/Transforms/CreateBenchmarkFuncs.cpp index 352ae67fadb1..656158c8da20 100644 --- a/iree/compiler/Dialect/Flow/Transforms/CreateBenchmarkFuncs.cpp +++ b/iree/compiler/Dialect/Flow/Transforms/CreateBenchmarkFuncs.cpp @@ -69,7 +69,7 @@ class CreateBenchmarkFuncs auto dummyWorkload = blockBuilder.create(loc, 0); auto dispatchOp = blockBuilder.create( loc, dispatchEntryOp, ValueRange{dummyWorkload}, - funcType.getResults(), args); + funcType.getResults(), ValueRange{}, args, ValueRange{}); blockBuilder.create(loc, dispatchOp.getResults()); } } diff --git a/iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensors.cpp b/iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensors.cpp index 1a2fbccec1f3..739a3c1617c4 100644 --- a/iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensors.cpp +++ b/iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensors.cpp @@ -16,6 +16,7 @@ #include "iree/compiler/Dialect/Flow/IR/FlowOps.h" #include "iree/compiler/Dialect/Flow/IR/FlowTypes.h" #include "iree/compiler/Dialect/Flow/Transforms/DestructiveUpdateUtils.h" +#include "iree/compiler/Dialect/Shape/IR/Builders.h" #include "iree/compiler/Dialect/Shape/IR/ShapeDialect.h" #include "iree/compiler/Dialect/Shape/IR/ShapeOps.h" #include "llvm/ADT/STLExtras.h" @@ -219,7 +220,9 @@ static std::pair buildOperandLessFlowDispatchWorkgroupOp(PatternRewriter &rewriter, Location loc, ArrayRef count, Operation *op) { auto dispatchOp = rewriter.create( - loc, count, op->getResultTypes(), ValueRange{}); + loc, count, op->getResultTypes(), /*result_dims=*/ValueRange{}, + /*operands=*/ValueRange{}, + /*operand_dims=*/ValueRange{}); Region ®ion = dispatchOp.body(); Block *block = ®ion.front(); Operation *clonedOp; @@ -446,8 +449,22 @@ static LogicalResult legalizeDispatchWorkgroupOperands( block.eraseArguments( llvm::to_vector<4>(llvm::seq(0, numOldBBArgs))); + // Gather the dynamic dimensions for all operands. + SmallVector operandDynamicDims; + OpBuilder builder(dispatchOp); + for (Value operand : valuesDefinedAbove) { + if (auto rt = operand.getType().dyn_cast()) { + for (unsigned i = 0; i < rt.getRank(); ++i) { + if (!rt.isDynamicDim(i)) continue; + auto dim = builder.createOrFold(dispatchOp.getLoc(), operand, i); + operandDynamicDims.push_back(dim); + } + } + } + // Set the values captured from above as the new operands. dispatchOp.operandsMutable().assign(llvm::to_vector<4>(valuesDefinedAbove)); + dispatchOp.operand_dimsMutable().assign(operandDynamicDims); return success(); } @@ -510,6 +527,13 @@ struct TileAndDistributeOnTensorsPattern count.resize(getNumTilableLoops(op)); auto workload = convertToWorkload(rewriter, loc, count); + // Capture dynamic result dimensions. + SmallVector resultDynamicDims; + for (auto result : linalgOp.outputs()) { + resultDynamicDims.append(Shape::buildOrFindDynamicDimsForValue( + linalgOp.getLoc(), result, rewriter)); + } + // Note: DispatchOutputStoreOp generated by the // `buildOperandLessFlowDispatchWorkgroupOp` is an abstraction jump that // consumes the SSA value produced by `clonedOp` but it does not comply with @@ -518,10 +542,11 @@ struct TileAndDistributeOnTensorsPattern // output tensors". Similarly to sequentialized SPMD loops, the semantics // is valid assuming a sequential ordering of execution. After destructive // update rewrites, the abstraction gap disappears. - auto en = - buildOperandLessFlowDispatchWorkgroupOp(rewriter, loc, workload, op); - linalg::LinalgOp clonedLinalgOp = cast(en.second); + auto en = buildOperandLessFlowDispatchWorkgroupOp(rewriter, loc, workload, + linalgOp); IREE::Flow::DispatchWorkgroupsOp dispatchOp = en.first; + linalg::LinalgOp clonedLinalgOp = cast(en.second); + dispatchOp.result_dimsMutable().assign(resultDynamicDims); // Scoped within DispatchWorkgroupOp. OpBuilder::InsertionGuard g(rewriter); @@ -546,7 +571,7 @@ struct TileAndDistributeOnTensorsPattern tiledLinalgOp.op.getOperation()->removeAttr(kRootOpAttr); rewriter.replaceOpWithIf( - op, dispatchOp.getOperation()->getResults(), + op, dispatchOp.getResults(), [&](OpOperand &operand) { return !isa(operand.getOwner()); }); return success(); } @@ -554,7 +579,7 @@ struct TileAndDistributeOnTensorsPattern /// The workload is computed based on the problem size. For a given operation, /// return the problem size. -static Optional> getProblemSize(PatternRewriter &rewriter, +static Optional> getResultShape(PatternRewriter &rewriter, Operation *op) { Location loc = op->getLoc(); auto getShapeOfShapedTypeVal = [&](Value v) -> SmallVector { @@ -599,11 +624,13 @@ struct MakeDispatchWorkgroupsOp : public RewritePattern { })) { return failure(); } + // The workgroup count is based on the result shape. if (op->getNumResults() != 1) return failure(); - Optional> countOpt = getProblemSize(rewriter, op); - if (!countOpt) return failure(); - SmallVector count = *countOpt; + Optional> resultShapeOpt = + getResultShape(rewriter, op); + if (!resultShapeOpt) return failure(); + SmallVector resultShape = *resultShapeOpt; // TODO(ravishankarm): For now the Flow -> HAL conversion only handles // workload count of 3, though it should be generalized. For now making sure @@ -611,6 +638,7 @@ struct MakeDispatchWorkgroupsOp : public RewritePattern { // workloads for all higher dimensions greater than or equal to // kNumMaxParallelDims. Location loc = op->getLoc(); + SmallVector count = resultShape; if (count.size() > kNumMaxParallelDims) { unsigned numSymbols = 0; AffineExpr expr = rewriter.getAffineSymbolExpr(numSymbols++); @@ -625,9 +653,22 @@ struct MakeDispatchWorkgroupsOp : public RewritePattern { ArrayRef(count).take_back(kNumMaxParallelDims)); } auto workload = convertToWorkload(rewriter, loc, count); + + // Capture dynamic result dimensions. + assert(op->getNumResults() == 1 && "currently assuming a single result"); + auto resultType = op->getResult(0).getType().cast(); + SmallVector resultDynamicDims; + for (unsigned i = 0; i < resultType.getRank(); ++i) { + if (resultType.isDynamicDim(i)) { + resultDynamicDims.push_back(resultShape[i]); + } + } + auto en = buildOperandLessFlowDispatchWorkgroupOp(rewriter, op->getLoc(), workload, op); IREE::Flow::DispatchWorkgroupsOp dispatchOp = en.first; + dispatchOp.result_dimsMutable().assign(resultDynamicDims); + rewriter.replaceOpWithIf(op, dispatchOp.getOperation()->getResults(), [&](OpOperand &operand) { Operation *user = operand.getOwner(); diff --git a/iree/compiler/Dialect/Flow/Transforms/FormStreams.cpp b/iree/compiler/Dialect/Flow/Transforms/FormStreams.cpp index 1d5b8f50903b..21c51f07831b 100644 --- a/iree/compiler/Dialect/Flow/Transforms/FormStreams.cpp +++ b/iree/compiler/Dialect/Flow/Transforms/FormStreams.cpp @@ -17,6 +17,7 @@ #include "iree/compiler/Dialect/Flow/IR/FlowDialect.h" #include "iree/compiler/Dialect/Flow/IR/FlowOps.h" #include "iree/compiler/Dialect/Flow/Transforms/Passes.h" +#include "iree/compiler/Dialect/Shape/IR/Builders.h" #include "iree/compiler/Dialect/Shape/IR/ShapeDialect.h" #include "iree/compiler/Dialect/Shape/IR/ShapeOps.h" #include "iree/compiler/Dialect/Shape/Utils/TypeConversion.h" @@ -61,21 +62,6 @@ static inline bool usefulStreamWork(ArrayRef currentStreamOps) { return llvm::any_of(currentStreamOps, usefulStreamOp); } -// Expand any compound types to primitive types in the stream fragment. -static void expandFragmentToPrimitiveTypes(ExStreamFragmentOp fragmentOp) { - auto loc = fragmentOp.getLoc(); - Block *entryBlock = &fragmentOp.body().front(); - auto &typeExpander = Shape::getShapeToPrimitiveTypeExpander(); - OpBuilder expandBuilder(fragmentOp.getContext()); - (void)typeExpander.expandBlockSignature(loc, entryBlock, expandBuilder); - SmallVector origFragmentArgs(fragmentOp.args()); - SmallVector newFragmentArgs; - expandBuilder.setInsertionPoint(fragmentOp); - (void)typeExpander.expandSourceValuesToTarget(loc, origFragmentArgs, - newFragmentArgs, expandBuilder); - fragmentOp.getOperation()->setOperands(newFragmentArgs); -} - // Temporary hack to get the experimental stream ops constructed. In the future // this will run an analysis to identify compatible dispatches across the entire // function CFG, create the streams, and then thread the streams through the CFG @@ -130,10 +116,10 @@ class FormStreamsPass : public PassWrapper { llvm::SmallSetVector streamOpSet{streamOps.begin(), streamOps.end()}; SmallVector fragmentOperands; + SmallVector fragmentOperandDims; SmallVector fragmentResults; + SmallVector fragmentResultDims; SmallVector fragmentResultTypes; - SmallVector tieShapeOps; - SmallVector outsideTieShapeOperands; for (auto *op : streamOps) { for (auto operand : op->getOperands()) { if (std::find(fragmentOperands.begin(), fragmentOperands.end(), @@ -141,16 +127,10 @@ class FormStreamsPass : public PassWrapper { if (!operand.getDefiningOp() || !streamOpSet.count(operand.getDefiningOp())) { fragmentOperands.push_back(operand); - - auto operandDefiningOp = operand.getDefiningOp(); - if (operandDefiningOp && - llvm::isa(operandDefiningOp)) { - tieShapeOps.push_back(operand.getDefiningOp()); - auto definingOp = - dyn_cast(operand.getDefiningOp()); - for (auto arg : definingOp.getOperands()) { - outsideTieShapeOperands.push_back(arg); - } + if (operand.getType().isa()) { + auto dynamicDims = Shape::buildOrFindDynamicDimsForValue( + fragmentLoc, operand, blockBuilder); + fragmentOperandDims.append(dynamicDims); } } } @@ -167,32 +147,28 @@ class FormStreamsPass : public PassWrapper { if (!onlyStreamUses) { fragmentResults.push_back(result); fragmentResultTypes.push_back(result.getType()); + if (result.getType().isa()) { + auto dynamicDims = Shape::buildOrFindDynamicDimsForValue( + fragmentLoc, result, blockBuilder); + fragmentResultDims.append(dynamicDims); + } } } } - // TODO(Tao Peng): pass args(operand and shape) which need by outside - // tie_shape into fragment body, and ignore the tie_shape arg passed into - // the fragment, it will not be used, and will be deleted by canonicalizer - // later. - outsideTieShapeOperands.append(fragmentOperands.begin(), - fragmentOperands.end()); - fragmentOperands = outsideTieShapeOperands; - // Create the fragment and clone in all of the ops. auto fragmentOp = blockBuilder.create( - fragmentLoc, fragmentResultTypes, fragmentOperands); + fragmentLoc, fragmentResultTypes, fragmentResultDims, fragmentOperands, + fragmentOperandDims); auto *entryBlock = new Block(); fragmentOp.body().getBlocks().push_back(entryBlock); - entryBlock->addArguments(llvm::to_vector<8>(fragmentOp.getOperandTypes())); + entryBlock->addArguments(TypeRange(fragmentOp.operands())); BlockAndValueMapping mapping; - for (auto arg : entryBlock->getArguments()) { + for (unsigned i = 0; i < fragmentOperands.size(); ++i) { + auto arg = entryBlock->getArgument(i); mapping.map(fragmentOperands[arg.getArgNumber()], arg); } OpBuilder fragmentBuilder = OpBuilder::atBlockEnd(entryBlock); - for (auto *op : tieShapeOps) { - fragmentBuilder.clone(*op, mapping); - } for (auto *op : streamOps) { fragmentBuilder.clone(*op, mapping); } @@ -201,20 +177,18 @@ class FormStreamsPass : public PassWrapper { llvm::to_vector<8>(llvm::map_range(fragmentResults, [&](Value value) { return mapping.lookup(value); }))); - for (auto resultOldNew : - llvm::zip(fragmentResults, fragmentOp.getResults())) { + for (auto resultOldNew : llvm::zip(fragmentResults, fragmentOp.results())) { auto oldValue = std::get<0>(resultOldNew); auto newValue = std::get<1>(resultOldNew); oldValue.replaceAllUsesWith(newValue); } // Erase the ops from the block now that we've cloned them. + // Note the backwards order as the ops may have dependencies on each other + // and we have to erase the consumers before the producers. for (auto *op : llvm::reverse(streamOps)) { op->erase(); } - - // Expand any shape types to corresponding primitives. - expandFragmentToPrimitiveTypes(fragmentOp); } }; diff --git a/iree/compiler/Dialect/Flow/Transforms/IdentifyDispatchRegions2.cpp b/iree/compiler/Dialect/Flow/Transforms/IdentifyDispatchRegions2.cpp index dadac5d560a5..3d6e70144348 100644 --- a/iree/compiler/Dialect/Flow/Transforms/IdentifyDispatchRegions2.cpp +++ b/iree/compiler/Dialect/Flow/Transforms/IdentifyDispatchRegions2.cpp @@ -248,7 +248,6 @@ LogicalResult fuseInputs(DispatchRegion &dispatchRegion, // makes this simple check safe. // The dispatch region must be optimized to remove unused arguments // resulting from this fusion. - DispatchRegionOp::dceOperandsAndResults(dispatchRegion.op); if (nextOp->use_empty()) { nextOp->erase(); } @@ -386,8 +385,11 @@ LogicalResult processBlock(Block &block, OpDispatchPolicy &policy) { if (failed(fuseInputs(*dispatchRegion, policy))) return failure(); // Ensure all unused operands and results are dce'd. - DispatchRegionOp::dceOperandsAndResults(dispatchRegion->op); - hoistDispatchRegionMetadataOps(*dispatchRegion, policy); + // Note that this may delete the op itself if it is unused. + optimizeClosureOp(dispatchRegion->op); + if (dispatchRegion->op) { + hoistDispatchRegionMetadataOps(*dispatchRegion, policy); + } } return success(); } diff --git a/iree/compiler/Dialect/Flow/Transforms/OutlineDispatchRegions.cpp b/iree/compiler/Dialect/Flow/Transforms/OutlineDispatchRegions.cpp index ddda1be83c9a..7816d6700178 100644 --- a/iree/compiler/Dialect/Flow/Transforms/OutlineDispatchRegions.cpp +++ b/iree/compiler/Dialect/Flow/Transforms/OutlineDispatchRegions.cpp @@ -17,6 +17,7 @@ #include "iree/compiler/Dialect/Flow/Analysis/Dispatchability.h" #include "iree/compiler/Dialect/Flow/IR/FlowOps.h" #include "iree/compiler/Dialect/Flow/Utils/DispatchUtils.h" +#include "iree/compiler/Dialect/Shape/IR/Builders.h" #include "iree/compiler/Dialect/Shape/IR/ShapeOps.h" #include "iree/compiler/Dialect/Shape/IR/ShapeTypes.h" #include "iree/compiler/Dialect/Shape/Utils/TypeConversion.h" @@ -76,10 +77,26 @@ LogicalResult convertToDispatchOp(DispatchRegionOp regionOp, getTensorTypeArgs(newArgs)); } + SmallVector operandDynamicDims; + for (auto operand : regionOp.args()) { + if (operand.getType().isa()) { + operandDynamicDims.append(Shape::buildOrFindDynamicDimsForValue( + regionOp.getLoc(), operand, builder)); + } + } + SmallVector resultDynamicDims; + for (auto result : regionOp.results()) { + if (result.getType().isa()) { + resultDynamicDims.append(Shape::buildOrFindDynamicDimsForValue( + regionOp.getLoc(), result, builder)); + } + } + // Create the dispatch op to the executable function. auto dispatchOp = builder.create( regionOp.getLoc(), entryPointOp, ValueRange{regionOp.workload()}, - outlinedFuncOp.getType().getResults(), newArgs); + outlinedFuncOp.getType().getResults(), resultDynamicDims, newArgs, + operandDynamicDims); if (traceDispatchTensors) { std::string str = "Output for " + std::string(outlinedFuncOp.getName()); diff --git a/iree/compiler/Dialect/Flow/Transforms/OutlineDispatchRegions2.cpp b/iree/compiler/Dialect/Flow/Transforms/OutlineDispatchRegions2.cpp index 4b704c6788d6..1ffbd9c86aca 100644 --- a/iree/compiler/Dialect/Flow/Transforms/OutlineDispatchRegions2.cpp +++ b/iree/compiler/Dialect/Flow/Transforms/OutlineDispatchRegions2.cpp @@ -37,28 +37,6 @@ namespace IREE { namespace Flow { namespace { -// Expands dynamic dimensions of a shaped type to their individual values. -// Will walk the shape IR hierarchy to resolve the dimensions as possible. -static void insertDynamicShapeDimOperands(DispatchWorkgroupsOp regionOp, - Value value, - SmallVectorImpl &newOperands, - OpBuilder &builder) { - auto shapedType = value.getType().cast(); - if (shapedType.hasStaticShape()) return; - - // NOTE: this may insert new shape values at |builder|, which is prior to our - // dispatch operation. All new values that are built can only depend on SSA - // values that are defined prior to the region op. - auto shapeValue = Shape::buildOrFindRankedShapeForValue( - regionOp.getLoc(), value, builder.getIndexType(), builder); - for (int dim = 0, e = shapedType.getRank(); dim < e; ++dim) { - if (shapedType.isDynamicDim(dim)) { - newOperands.push_back(builder.create( - regionOp.getLoc(), shapeValue, dim)); - } - } -} - // Converts a dispatch region op into a dispatch op to the outlined region. static LogicalResult convertToDispatchOp(DispatchWorkgroupsOp regionOp, ExecutableOp executableOp, @@ -67,25 +45,37 @@ static LogicalResult convertToDispatchOp(DispatchWorkgroupsOp regionOp, OpBuilder builder(regionOp); // Perform shape to primitive type expansion. + // NOTE: this may insert new shape values at |builder|, which is prior to + // our dispatch operation. All new values that are built can only depend + // on SSA values that are defined prior to the region op. SmallVector newOperands; + SmallVector operandDynamicDims; + SmallVector resultDynamicDims; for (auto operand : regionOp.operands()) { newOperands.push_back(operand); } for (auto operand : regionOp.operands()) { - if (operand.getType().isa()) { - insertDynamicShapeDimOperands(regionOp, operand, newOperands, builder); + if (operand.getType().isa()) { + auto dynamicDims = Shape::buildOrFindDynamicDimsForValue( + regionOp.getLoc(), operand, builder); + operandDynamicDims.append(dynamicDims); + newOperands.append(dynamicDims); } } for (auto result : regionOp.results()) { - if (result.getType().isa()) { - insertDynamicShapeDimOperands(regionOp, result, newOperands, builder); + if (result.getType().isa()) { + auto dynamicDims = Shape::buildOrFindDynamicDimsForValue( + regionOp.getLoc(), result, builder); + resultDynamicDims.append(dynamicDims); + newOperands.append(dynamicDims); } } // Create the dispatch op to the executable function. auto dispatchOp = builder.create( regionOp.getLoc(), entryPointOp, regionOp.workgroup_count(), - regionOp.getResultTypes(), newOperands); + regionOp.getResultTypes(), resultDynamicDims, newOperands, + operandDynamicDims); // Replace uses of the existing results with the new results. for (int i = 0; i < regionOp.getNumResults(); ++i) { diff --git a/iree/compiler/Dialect/Flow/Transforms/Passes.cpp b/iree/compiler/Dialect/Flow/Transforms/Passes.cpp index f5a60db96162..50204c445600 100644 --- a/iree/compiler/Dialect/Flow/Transforms/Passes.cpp +++ b/iree/compiler/Dialect/Flow/Transforms/Passes.cpp @@ -195,8 +195,7 @@ void buildFlowTransformPassPipeline(OpPassManager &passManager) { passManager.addNestedPass(createCanonicalizerPass()); addHLOToLinalgOnTensorsPasses(passManager, clEnableLinalgOnTensorsDispatch); passManager.addNestedPass(createDispatchLinalgOnTensorsPass()); - passManager.addPass(createCanonicalizerPass()); - passManager.addPass(createCSEPass()); + passManager.addNestedPass(createCanonicalizerPass()); } // First perform module-level analysis that following passes will use to query @@ -258,15 +257,16 @@ void buildFlowTransformPassPipeline(OpPassManager &passManager) { passManager.addNestedPass(createCanonicalizerPass()); passManager.addNestedPass(IREE::Flow::createFormStreamsPass()); - // Forming streams involves a fair amount of subgraph stitching, which can - // cause duplication. Run CSE to collapse. - passManager.addNestedPass(createCanonicalizerPass()); - passManager.addNestedPass(createCSEPass()); // Prior to leaving the pipeline we need to clean things up for following // layers. These transforms may be undone by subsequent CSE/folding passes. passManager.addPass(createOutlineLargeConstantsPass()); + // Forming streams involves a fair amount of subgraph stitching, which can + // cause duplication. Run CSE to collapse. + passManager.addNestedPass(createCanonicalizerPass()); + passManager.addNestedPass(createCSEPass()); + // Symbol DCE any remaining variables/functions that are now no longer // required. passManager.addPass(createSymbolDCEPass()); diff --git a/iree/compiler/Dialect/Flow/Transforms/RematerializeDispatchConstants.cpp b/iree/compiler/Dialect/Flow/Transforms/RematerializeDispatchConstants.cpp index 43a184240e8a..4f82b470bf32 100644 --- a/iree/compiler/Dialect/Flow/Transforms/RematerializeDispatchConstants.cpp +++ b/iree/compiler/Dialect/Flow/Transforms/RematerializeDispatchConstants.cpp @@ -117,25 +117,22 @@ Value cloneOpTreeIntoBlock(Value sourceValue, Block *targetBlock, // Modify the second operand of the SegmentSize attribute // TODO(ataei): Remove this once we have flow.dispatch.workgroups only here. template -void modifyOperandSegmentSizeAttr(DispatchOpType dispatchOp, int32_t argCount); +void dropOperands(DispatchOpType dispatchOp, ArrayRef deadIndices); template <> -void modifyOperandSegmentSizeAttr(DispatchRegionOp dispatchRegionOp, - int32_t argCount) {} +void dropOperands(DispatchRegionOp dispatchRegionOp, + ArrayRef deadIndices) { + for (unsigned argIndex : llvm::reverse(deadIndices)) { + dispatchRegionOp.argsMutable().erase(argIndex); + } +} template <> -void modifyOperandSegmentSizeAttr(DispatchWorkgroupsOp dispatchWorkgroupsOp, - int32_t argCount) { - dispatchWorkgroupsOp.getOperation()->setAttr( - DispatchWorkgroupsOp::getOperandSegmentSizeAttr(), - DenseIntElementsAttr::get( - VectorType::get( - 2, IntegerType::get(dispatchWorkgroupsOp.getContext(), 32)), - ArrayRef( - {static_cast( - dispatchWorkgroupsOp.workgroup_count().size()), - static_cast(dispatchWorkgroupsOp.operands().size() - - argCount)}))); +void dropOperands(DispatchWorkgroupsOp dispatchWorkgroupsOp, + ArrayRef deadIndices) { + for (unsigned argIndex : llvm::reverse(deadIndices)) { + dispatchWorkgroupsOp.operandsMutable().erase(argIndex); + } } // Inlines use of the given |value| from outside of a dispatch region to inside @@ -170,11 +167,9 @@ LogicalResult inlineDispatchRegionOperandsUsingValue(DispatchOpType dispatchOp, // Remove the dispatch region args and the block args that have been // replaced. for (unsigned argIndex : llvm::reverse(argIndices)) { - dispatchOp.getOperation()->eraseOperand( - dispatchOp.mapArgOperandToOpOperand(argIndex)); entryBlock.eraseArgument(argIndex); } - modifyOperandSegmentSizeAttr(dispatchOp, argIndices.size()); + dropOperands(dispatchOp, argIndices); return success(); } diff --git a/iree/compiler/Dialect/Flow/Transforms/test/create_benchmark_funcs.mlir b/iree/compiler/Dialect/Flow/Transforms/test/create_benchmark_funcs.mlir index 6d785ebbd693..6f03139ee293 100644 --- a/iree/compiler/Dialect/Flow/Transforms/test/create_benchmark_funcs.mlir +++ b/iree/compiler/Dialect/Flow/Transforms/test/create_benchmark_funcs.mlir @@ -12,8 +12,8 @@ module { // CHECK: func @two_dispatch_ex_dispatch_0_entry // CHECK: %{{.+}} = flow.variable.load @[[IN0_0]] : tensor<5x3xf32> // CHECK: %{{.+}} = flow.variable.load @[[IN0_1]] : tensor<3x5xf32> -// CHECK: %[[RES:.+]] = flow.ex.stream.fragment({{.+}}) -> tensor<5x5xf32> { -// CHECK: %[[DISPATCH_RES:.+]] = flow.dispatch @two_dispatch_ex_dispatch_0::@two_dispatch_ex_dispatch_0[%{{.+}}] (%{{.+}}, %{{.+}}) : (tensor<5x3xf32>, tensor<3x5xf32>) -> tensor<5x5xf32> +// CHECK: %[[RES:.+]] = flow.ex.stream.fragment({{.+}}) -> (tensor<5x5xf32>) = +// CHECK: %[[DISPATCH_RES:.+]] = flow.dispatch @two_dispatch_ex_dispatch_0::@two_dispatch_ex_dispatch_0[%{{.+}}](%{{.+}}, %{{.+}}) : (tensor<5x3xf32>, tensor<3x5xf32>) -> (tensor<5x5xf32>) // CHECK: flow.return %[[DISPATCH_RES]] : tensor<5x5xf32> // CHECK: return %[[RES]] : tensor<5x5xf32> // @@ -22,8 +22,8 @@ module { // CHECK: func @two_dispatch_ex_dispatch_1_entry // CHECK: %{{.+}} = flow.variable.load @[[IN1_0]] : tensor<3x5xf32> // CHECK: %{{.+}} = flow.variable.load @[[IN1_1]] : tensor<5x5xf32> -// CHECK: %[[RES:.+]] = flow.ex.stream.fragment({{.+}}) -> tensor<3x5xf32> -// CHECK: %[[DISPATCH_RES:.+]] = flow.dispatch @two_dispatch_ex_dispatch_1::@two_dispatch_ex_dispatch_1[%{{.+}}] (%{{.+}}, %{{.+}}) : (tensor<3x5xf32>, tensor<5x5xf32>) -> tensor<3x5xf32> +// CHECK: %[[RES:.+]] = flow.ex.stream.fragment({{.+}}) -> (tensor<3x5xf32>) +// CHECK: %[[DISPATCH_RES:.+]] = flow.dispatch @two_dispatch_ex_dispatch_1::@two_dispatch_ex_dispatch_1[%{{.+}}](%{{.+}}, %{{.+}}) : (tensor<3x5xf32>, tensor<5x5xf32>) -> (tensor<3x5xf32>) // CHECK: flow.return %[[DISPATCH_RES]] : tensor<3x5xf32> // CHECK: return %[[RES]] : tensor<3x5xf32> // @@ -32,7 +32,7 @@ module { // CHECK: func @two_dispatch_dummy_args() // CHECK: %{{.+}} = flow.variable.load @[[MAIN_IN_0]] : tensor<5x3xf32> // CHECK: %{{.+}} = flow.variable.load @[[MAIN_IN_1]] : tensor<3x5xf32> -// CHECK: flow.ex.stream.fragment({{.+}}) -> (tensor<5x5xf32>, tensor<3x5xf32>) { +// CHECK: flow.ex.stream.fragment({{.+}}) -> (tensor<5x5xf32>, tensor<3x5xf32>) = // CHECK: %[[DISPATCH_RES1:.+]] = flow.dispatch // CHECK: %[[DISPATCH_RES2:.+]] = flow.dispatch // CHECK: flow.return %[[DISPATCH_RES1]], %[[DISPATCH_RES2]] : tensor<5x5xf32>, tensor<3x5xf32> diff --git a/iree/compiler/Dialect/Flow/Transforms/test/deduplicate_executables.mlir b/iree/compiler/Dialect/Flow/Transforms/test/deduplicate_executables.mlir index 60d38175b83c..916becdb65bd 100644 --- a/iree/compiler/Dialect/Flow/Transforms/test/deduplicate_executables.mlir +++ b/iree/compiler/Dialect/Flow/Transforms/test/deduplicate_executables.mlir @@ -13,7 +13,7 @@ flow.executable @single_executable_ex_0 { // CHECK-LABEL: func @single_executable func @single_executable(%arg0: tensor<4xf32>) -> tensor<4xf32> { %c4 = constant 4 : index - // CHECK: %0 = flow.dispatch @single_executable_ex_0::@single_executable_entry_0[%c4] (%arg0) : (tensor<4xf32>) -> tensor<4xf32> + // CHECK: %0 = flow.dispatch @single_executable_ex_0::@single_executable_entry_0[%c4](%arg0) : (tensor<4xf32>) -> (tensor<4xf32>) %0 = flow.dispatch @single_executable_ex_0::@single_executable_entry_0[%c4] (%arg0) : (tensor<4xf32>) -> tensor<4xf32> return %0 : tensor<4xf32> } @@ -53,11 +53,11 @@ flow.executable @duplicate_executables_ex_2 { // CHECK-LABEL: func @duplicate_executables func @duplicate_executables(%arg0: tensor<4xf32>) -> tensor<4xf32> { %c4 = constant 4 : index - // CHECK: %0 = flow.dispatch @duplicate_executables_ex_0::@duplicate_executables_entry_0[%c4] (%arg0) : (tensor<4xf32>) -> tensor<4xf32> + // CHECK: %0 = flow.dispatch @duplicate_executables_ex_0::@duplicate_executables_entry_0[%c4](%arg0) : (tensor<4xf32>) -> (tensor<4xf32>) %0 = flow.dispatch @duplicate_executables_ex_0::@duplicate_executables_entry_0[%c4] (%arg0) : (tensor<4xf32>) -> tensor<4xf32> - // CHECK: %1 = flow.dispatch @duplicate_executables_ex_0::@duplicate_executables_entry_0[%c4] (%arg0) : (tensor<4xf32>) -> tensor<4xf32> + // CHECK: %1 = flow.dispatch @duplicate_executables_ex_0::@duplicate_executables_entry_0[%c4](%arg0) : (tensor<4xf32>) -> (tensor<4xf32>) %1 = flow.dispatch @duplicate_executables_ex_1::@duplicate_executables_entry_1[%c4] (%arg0) : (tensor<4xf32>) -> tensor<4xf32> - // CHECK: %2 = flow.dispatch @duplicate_executables_ex_2::@duplicate_executables_entry_2[%c4] (%arg0) : (tensor<4xf32>) -> tensor<4xf32> + // CHECK: %2 = flow.dispatch @duplicate_executables_ex_2::@duplicate_executables_entry_2[%c4](%arg0) : (tensor<4xf32>) -> (tensor<4xf32>) %2 = flow.dispatch @duplicate_executables_ex_2::@duplicate_executables_entry_2[%c4] (%arg0) : (tensor<4xf32>) -> tensor<4xf32> return %0 : tensor<4xf32> } @@ -87,9 +87,9 @@ flow.executable @same_ops_diff_operands_ex_1 { // CHECK-LABEL: func @same_ops_diff_operands func @same_ops_diff_operands(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2xi32> { %c4 = constant 4 : index - // CHECK: %0 = flow.dispatch @same_ops_diff_operands_ex_0::@entry_0[%c4] (%arg0, %arg1) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> + // CHECK: %0 = flow.dispatch @same_ops_diff_operands_ex_0::@entry_0[%c4](%arg0, %arg1) : (tensor<2xi32>, tensor<2xi32>) -> (tensor<2xi32>) %0 = flow.dispatch @same_ops_diff_operands_ex_0::@entry_0[%c4] (%arg0, %arg1) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> - // CHECK: %1 = flow.dispatch @same_ops_diff_operands_ex_1::@entry_1[%c4] (%arg0, %arg1) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> + // CHECK: %1 = flow.dispatch @same_ops_diff_operands_ex_1::@entry_1[%c4](%arg0, %arg1) : (tensor<2xi32>, tensor<2xi32>) -> (tensor<2xi32>) %1 = flow.dispatch @same_ops_diff_operands_ex_1::@entry_1[%c4] (%arg0, %arg1) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> return %0 : tensor<2xi32> } @@ -129,13 +129,13 @@ flow.executable @multiple_entry_points_ex_1 { // CHECK-LABEL: func @multiple_entry_points func @multiple_entry_points(%arg0: tensor<4xf32>) -> tensor<4xf32> { %c4 = constant 4 : index - // CHECK: %0 = flow.dispatch @multiple_entry_points_ex_0::@multiple_entry_points_0_entry_0[%c4] (%arg0) : (tensor<4xf32>) -> tensor<4xf32> + // CHECK: %0 = flow.dispatch @multiple_entry_points_ex_0::@multiple_entry_points_0_entry_0[%c4](%arg0) : (tensor<4xf32>) -> (tensor<4xf32>) %0 = flow.dispatch @multiple_entry_points_ex_0::@multiple_entry_points_0_entry_0[%c4] (%arg0) : (tensor<4xf32>) -> tensor<4xf32> - // CHECK: %1 = flow.dispatch @multiple_entry_points_ex_0::@multiple_entry_points_0_entry_1[%c4] (%arg0) : (tensor<4xf32>) -> tensor<4xf32> + // CHECK: %1 = flow.dispatch @multiple_entry_points_ex_0::@multiple_entry_points_0_entry_1[%c4](%arg0) : (tensor<4xf32>) -> (tensor<4xf32>) %1 = flow.dispatch @multiple_entry_points_ex_0::@multiple_entry_points_0_entry_1[%c4] (%arg0) : (tensor<4xf32>) -> tensor<4xf32> - // CHECK: %2 = flow.dispatch @multiple_entry_points_ex_0::@multiple_entry_points_0_entry_0[%c4] (%arg0) : (tensor<4xf32>) -> tensor<4xf32> + // CHECK: %2 = flow.dispatch @multiple_entry_points_ex_0::@multiple_entry_points_0_entry_0[%c4](%arg0) : (tensor<4xf32>) -> (tensor<4xf32>) %2 = flow.dispatch @multiple_entry_points_ex_1::@multiple_entry_points_1_entry_0[%c4] (%arg0) : (tensor<4xf32>) -> tensor<4xf32> - // CHECK: %3 = flow.dispatch @multiple_entry_points_ex_0::@multiple_entry_points_0_entry_1[%c4] (%arg0) : (tensor<4xf32>) -> tensor<4xf32> + // CHECK: %3 = flow.dispatch @multiple_entry_points_ex_0::@multiple_entry_points_0_entry_1[%c4](%arg0) : (tensor<4xf32>) -> (tensor<4xf32>) %3 = flow.dispatch @multiple_entry_points_ex_1::@multiple_entry_points_1_entry_1[%c4] (%arg0) : (tensor<4xf32>) -> tensor<4xf32> return %0 : tensor<4xf32> } @@ -165,9 +165,9 @@ flow.executable @different_types_int_ex { // CHECK-LABEL: func @different_types func @different_types(%arg0: tensor<4xf32>) -> tensor<4xi1> { %c4 = constant 4 : index - // CHECK: %0 = flow.dispatch @different_types_float_ex::@different_types_float_entry[%c4] (%arg0) : (tensor<4xf32>) -> tensor<4xi1> + // CHECK: %0 = flow.dispatch @different_types_float_ex::@different_types_float_entry[%c4](%arg0) : (tensor<4xf32>) -> (tensor<4xi1>) %0 = flow.dispatch @different_types_float_ex::@different_types_float_entry[%c4] (%arg0) : (tensor<4xf32>) -> tensor<4xi1> - // CHECK: %1 = flow.dispatch @different_types_int_ex::@different_types_int_entry[%c4] (%arg0) : (tensor<4xf32>) -> tensor<4xi1> + // CHECK: %1 = flow.dispatch @different_types_int_ex::@different_types_int_entry[%c4](%arg0) : (tensor<4xf32>) -> (tensor<4xi1>) %1 = flow.dispatch @different_types_int_ex::@different_types_int_entry[%c4] (%arg0) : (tensor<4xf32>) -> tensor<4xi1> return %0 : tensor<4xi1> } @@ -222,11 +222,11 @@ flow.executable @nested_ops_ex_2 { // CHECK-LABEL: func @nested_ops func @nested_ops(%arg0: tensor<1x4xi32>) -> tensor<1xi32> { %c4 = constant 4 : index - // CHECK: %0 = flow.dispatch @nested_ops_ex_0::@nested_ops_entry_0[%c4] (%arg0) : (tensor<1x4xi32>) -> tensor<1xi32> + // CHECK: %0 = flow.dispatch @nested_ops_ex_0::@nested_ops_entry_0[%c4](%arg0) : (tensor<1x4xi32>) -> (tensor<1xi32>) %0 = flow.dispatch @nested_ops_ex_0::@nested_ops_entry_0[%c4] (%arg0) : (tensor<1x4xi32>) -> tensor<1xi32> - // CHECK: %1 = flow.dispatch @nested_ops_ex_0::@nested_ops_entry_0[%c4] (%arg0) : (tensor<1x4xi32>) -> tensor<1xi32> + // CHECK: %1 = flow.dispatch @nested_ops_ex_0::@nested_ops_entry_0[%c4](%arg0) : (tensor<1x4xi32>) -> (tensor<1xi32>) %1 = flow.dispatch @nested_ops_ex_0::@nested_ops_entry_0[%c4] (%arg0) : (tensor<1x4xi32>) -> tensor<1xi32> - // CHECK: %2 = flow.dispatch @nested_ops_ex_2::@nested_ops_entry_2[%c4] (%arg0) : (tensor<1x4xi32>) -> tensor<1xi32> + // CHECK: %2 = flow.dispatch @nested_ops_ex_2::@nested_ops_entry_2[%c4](%arg0) : (tensor<1x4xi32>) -> (tensor<1xi32>) %2 = flow.dispatch @nested_ops_ex_2::@nested_ops_entry_2[%c4] (%arg0) : (tensor<1x4xi32>) -> tensor<1xi32> return %0 : tensor<1xi32> } diff --git a/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors.mlir b/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors.mlir index 3895e67c9ed1..d0d16093f5a7 100644 --- a/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors.mlir +++ b/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors.mlir @@ -17,12 +17,12 @@ func @tensor() -> tensor<2x4xf32> { %C = iree.unfoldable_constant dense<1000.0> : tensor<2x4xf32> // %[[C2]] will be handled by a later RematerializeDispatchConstants - // CHECK: flow.dispatch.workgroups[%[[C4wg]], %[[C2wg]], %[[C1wg]]] (%[[outerA]], %[[outerB]], %[[outerC]]) : - // CHECK-SAME: (tensor<2x3xf32>, tensor<3x4xf32>, tensor<2x4xf32>) -> tensor<2x4xf32> = - // CHECK-SAME: (%[[A:[0-9a-z]*]] : !flow.dispatch.input<2x3xf32>, - // CHECK-SAME: %[[B:[0-9a-z]*]] : !flow.dispatch.input<3x4xf32>, - // CHECK-SAME: %[[C:[0-9a-z]*]] : !flow.dispatch.input<2x4xf32>, - // CHECK-SAME: %[[OUT:[0-9a-z]*]] : !flow.dispatch.output<2x4xf32>) { + // CHECK: flow.dispatch.workgroups[%[[C4wg]], %[[C2wg]], %[[C1wg]]](%[[outerA]], %[[outerB]], %[[outerC]]) : + // CHECK-SAME: (tensor<2x3xf32>, tensor<3x4xf32>, tensor<2x4xf32>) -> (tensor<2x4xf32>) = + // CHECK-NEXT: (%[[A:[0-9a-z]*]]: !flow.dispatch.input<2x3xf32>, + // CHECK-SAME: %[[B:[0-9a-z]*]]: !flow.dispatch.input<3x4xf32>, + // CHECK-SAME: %[[C:[0-9a-z]*]]: !flow.dispatch.input<2x4xf32>, + // CHECK-SAME: %[[OUT:[0-9a-z]*]]: !flow.dispatch.output<2x4xf32>) { // CHECK-DAG: %[[C0:.*]] = constant 0 : index // CHECK-DAG: %[[C1:.*]] = constant 1 : index // CHECK-DAG: %[[C2:.*]] = constant 2 : index @@ -63,6 +63,8 @@ func @tensor() -> tensor<2x4xf32> { return %E : tensor<2x4xf32> } +// ----- + // CHECK-LABEL: func @tensor2 func @tensor2(%A: tensor, %B: tensor, %C: tensor) -> tensor attributes {iree.module.export} @@ -89,6 +91,8 @@ func @tensor2(%A: tensor, %B: tensor, %C: tensor) return %D: tensor } +// ----- + // CHECK-LABEL: func @tensor3 func @tensor3(%A: tensor, %B: tensor, %C: tensor) -> tensor attributes {iree.module.export} @@ -115,6 +119,7 @@ func @tensor3(%A: tensor, %B: tensor, %C: tensor) return %D: tensor } +// ----- // CHECK-LABEL: func @tensor4 func @tensor4(%A: tensor, %B: tensor, %C: tensor) @@ -143,6 +148,8 @@ func @tensor4(%A: tensor, %B: tensor, %C: tensor) return %D: tensor } +// ----- + // CHECK: func @tensor5 // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor @@ -155,9 +162,9 @@ func @tensor5(%A: tensor, %B: tensor, %C: tensor) // CHECK-DAG: %[[C1:.+]] = constant 1 : index // CHECK-DAG: %[[D0:.+]] = dim %[[ARG2]], %[[C0]] // CHECK-DAG: %[[D1:.+]] = dim %[[ARG2]], %[[C1]] - // CHECK: %[[origCC:.+]] = flow.dispatch.workgroups[%[[D1]], %[[D0]], %[[C1]]] (%[[ARG2]]) - // CHECK-SAME: %[[ARG3:.+]] : !flow.dispatch.input - // CHECK-SAME: %[[ARG4:.+]] : !flow.dispatch.output + // CHECK: %[[origCC:.+]] = flow.dispatch.workgroups[%[[D1]], %[[D0]], %[[C1]]](%[[ARG2]]) + // CHECK-NEXT: %[[ARG3:.+]]: !flow.dispatch.input + // CHECK-SAME: %[[ARG4:.+]]: !flow.dispatch.output // CHECK: %[[LOAD:.+]] = flow.dispatch.input.load %[[ARG3]] // CHECK: %[[STOREVAL:.+]] = linalg.generic // CHECK-SAME: outs(%[[LOAD]] : tensor) diff --git a/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors_dynamic.mlir b/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors_dynamic.mlir index d474ef3b2759..15eb47ee2ab9 100644 --- a/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors_dynamic.mlir +++ b/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors_dynamic.mlir @@ -12,10 +12,10 @@ func @tensor(%arg0 : tensor, %arg1 : tensor, // CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: tensor // CHECK: flow.dispatch.workgroups // CHECK-SAME: (%[[ARG0]], %[[ARG1]], %[[ARG2]]) -// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]] : !flow.dispatch.input -// CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]+]] : !flow.dispatch.input -// CHECK-SAME: %[[ARG5:[a-zA-Z0-9_]+]] : !flow.dispatch.input -// CHECK-SAME: %[[ARG6:[a-zA-Z0-9_]+]] : !flow.dispatch.output +// CHECK-NEXT: %[[ARG3:[a-zA-Z0-9_]+]]: !flow.dispatch.input +// CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]+]]: !flow.dispatch.input +// CHECK-SAME: %[[ARG5:[a-zA-Z0-9_]+]]: !flow.dispatch.input +// CHECK-SAME: %[[ARG6:[a-zA-Z0-9_]+]]: !flow.dispatch.output // CHECK-DAG: %[[C0:.+]] = constant 0 : index // CHECK-DAG: %[[WGSIZE_X:.+]] = flow.dispatch.workgroup.size[0] // CHECK-DAG: %[[WGSIZE_Y:.+]] = flow.dispatch.workgroup.size[1] @@ -72,12 +72,12 @@ func @generic_op(%A: tensor, %B: tensor) -> tensor { // CHECK-DAG: %[[D0:.+]] = dim %[[ARG0]], %[[C0]] // CHECK-DAG: %[[D1:.+]] = dim %[[ARG0]], %[[C1]] // CHECK: flow.dispatch.workgroups -// CHECK-SAME: [%[[D1]], %[[D0]], %[[C1]]] (%[[ARG0]], %[[ARG1]], %[[D0]], %[[D1]]) -// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]] : !flow.dispatch.input -// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]] : !flow.dispatch.input -// CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]+]] : index -// CHECK-SAME: %[[ARG5:[a-zA-Z0-9_]+]] : index -// CHECK-SAME: %[[ARG6:[a-zA-Z0-9_]+]] : !flow.dispatch.output +// CHECK-SAME: [%[[D1]], %[[D0]], %[[C1]]](%[[ARG0]], %[[ARG1]], %[[D0]], %[[D1]]) +// CHECK-NEXT: %[[ARG2:[a-zA-Z0-9_]+]]: !flow.dispatch.input +// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: !flow.dispatch.input +// CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]+]]: index +// CHECK-SAME: %[[ARG5:[a-zA-Z0-9_]+]]: index +// CHECK-SAME: %[[ARG6:[a-zA-Z0-9_]+]]: !flow.dispatch.output // CHECK-DAG: %[[LOAD2:.+]] = flow.dispatch.input.load %[[ARG2]] : !flow.dispatch.input // CHECK-DAG: %[[INIT:.+]] = linalg.init_tensor [%[[ARG4]], %[[ARG5]]] // CHECK-DAG: %[[LOAD3:.+]] = flow.dispatch.input.load %[[ARG3]] : !flow.dispatch.input @@ -109,11 +109,11 @@ func @fuse_fill_with_producer(%A : tensor, %B : tensor) -> ten // CHECK: %[[N:.+]] = dim %[[ARG1]], %[[C1]] // CHECK: flow.dispatch.workgroups[%[[N]], %[[M]], %[[C1]]] // CHECK-SAME: (%[[M]], %[[N]], %[[ARG0]], %[[ARG1]]) -// CHECK-SAME: (%[[ARG2:[a-zA-Z0-9_]+]] : index -// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]] : index -// CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]+]] : !flow.dispatch.input -// CHECK-SAME: %[[ARG5:[a-zA-Z0-9_]+]] : !flow.dispatch.input -// CHECK-SAME: %[[ARG6:[a-zA-Z0-9_]+]] : !flow.dispatch.output) { +// CHECK-NEXT: (%[[ARG2:[a-zA-Z0-9_]+]]: index +// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: index +// CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]+]]: !flow.dispatch.input +// CHECK-SAME: %[[ARG5:[a-zA-Z0-9_]+]]: !flow.dispatch.input +// CHECK-SAME: %[[ARG6:[a-zA-Z0-9_]+]]: !flow.dispatch.output) { // CHECK: %[[ZERO:.+]] = constant 0.000000e+00 : f32 // CHECK: scf.for // CHECK: scf.for @@ -164,10 +164,10 @@ func @two_dispatches(%A : tensor, %B : tensor) -> tensor -// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]] : index -// CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]+]] : index -// CHECK-SAME: %[[ARG5:[a-zA-Z0-9_]+]] : !flow.dispatch.output) { +// CHECK-NEXT: (%[[ARG2:[a-zA-Z0-9_]+]]: !flow.dispatch.input +// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: index +// CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]+]]: index +// CHECK-SAME: %[[ARG5:[a-zA-Z0-9_]+]]: !flow.dispatch.output) { // CHECK: %[[ONE:.+]] = constant 1.0 // CHECK-DAG: %[[INPUT:.+]] = flow.dispatch.input.load %[[ARG2]] // CHECK-DAG: %[[INIT:.+]] = linalg.init_tensor @@ -188,12 +188,12 @@ func @two_dispatches(%A : tensor, %B : tensor) -> tensor -// NOCHECK-SAME: %[[ARG5:[a-zA-Z0-9_]+]] : !flow.dispatch.input -// NOCHECK-SAME: %[[ARG6:[a-zA-Z0-9_]+]] : !flow.dispatch.input -// NOCHECK-SAME: %[[ARG7:[a-zA-Z0-9_]+]] : !flow.dispatch.output) { +// NOCHECK-SAME: (%[[ARG2:[a-zA-Z0-9_]+]]: index +// NOCHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: index +// NOCHECK-SAME: %[[ARG4:[a-zA-Z0-9_]+]]: !flow.dispatch.input +// NOCHECK-SAME: %[[ARG5:[a-zA-Z0-9_]+]]: !flow.dispatch.input +// NOCHECK-SAME: %[[ARG6:[a-zA-Z0-9_]+]]: !flow.dispatch.input +// NOCHECK-SAME: %[[ARG7:[a-zA-Z0-9_]+]]: !flow.dispatch.output) { // NOCHECK: %[[ZERO:.+]] = constant 0.0 // NOCHECK: scf.for // NOCHECK: scf.for @@ -226,9 +226,9 @@ func @dot_general_lower() attributes {iree.module.export} { } // CHECK-LABEL: func @dot_general_lower // CHECK: flow.dispatch.workgroups[%{{.+}}, %{{.+}}, %{{.+}}] -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]] : !flow.dispatch.input<1x1x2xf32> -// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]] : !flow.dispatch.input<2x3xf32> -// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]] : !flow.dispatch.output<1x3xf32> +// CHECK-NEXT: %[[ARG0:[a-zA-Z0-9_]+]]: !flow.dispatch.input<1x1x2xf32> +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: !flow.dispatch.input<2x3xf32> +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: !flow.dispatch.output<1x3xf32> // CHECK-DAG: %[[ZERO:.+]] = constant 0.0 // CHECK: %[[LOAD:.+]] = flow.dispatch.input.load %[[ARG0]] // CHECK: %[[RESHAPE:.+]] = linalg.tensor_reshape %[[LOAD]] @@ -260,9 +260,9 @@ func @reshapeop(%arg0: tensor) -> tensor // CHECK-DAG: %[[D1:.+]] = dim %[[ARG0]], %[[C1]] // CHECK: %[[WORKLOAD:.+]] = affine.apply #[[MAP0]]()[%[[D0]], %[[D1]]] // CHECK: %[[RESULT:.+]] = flow.dispatch.workgroups -// CHECK-SAME: [%[[WORKLOAD]], %[[C1]], %[[C1]]] (%[[ARG0]]) -// CHECK-SAME: %[[ARG1:.+]] : !flow.dispatch.input -// CHECK-SAME: %[[ARG2:.+]] : !flow.dispatch.output +// CHECK-SAME: [%[[WORKLOAD]], %[[C1]], %[[C1]]](%[[ARG0]]) +// CHECK-NEXT: %[[ARG1:.+]]: !flow.dispatch.input +// CHECK-SAME: %[[ARG2:.+]]: !flow.dispatch.output // CHECK: %[[LOAD:.+]] = flow.dispatch.input.load %[[ARG1]] // CHECK: %[[RESHAPE:.+]] = linalg.tensor_reshape %[[LOAD]] [#[[MAP1]]] // CHECK: flow.dispatch.output.store %[[RESHAPE]], %[[ARG2]] @@ -364,7 +364,7 @@ func @pad_test(%arg0 : tensor, %arg1 : tensor, %arg2 : index, %4 = affine.apply affine_map<(d0)[s0, s1] -> (d0 + s0 + s1)>(%2)[%arg3, %arg5] %5 = linalg.init_tensor [%3, %4] : tensor %6 = linalg.fill(%5, %0) : tensor, f32 -> tensor - %7 = flow.tensor.update %arg0, %6[%arg2, %arg3] : tensor -> tensor + %7 = flow.tensor.update %arg0, %6[%arg2, %arg3] : tensor{%1, %2} -> tensor{%3, %4} return %7 : tensor } diff --git a/iree/compiler/Dialect/Flow/Transforms/test/fold_compatible_dispatch_regions.mlir b/iree/compiler/Dialect/Flow/Transforms/test/fold_compatible_dispatch_regions.mlir index 292b28052dd3..d7ec93eaf468 100644 --- a/iree/compiler/Dialect/Flow/Transforms/test/fold_compatible_dispatch_regions.mlir +++ b/iree/compiler/Dialect/Flow/Transforms/test/fold_compatible_dispatch_regions.mlir @@ -11,7 +11,7 @@ func @noFolding(%arg0 : tensor<4xf32>) -> tensor<4xf32> { // CHECK-LABEL: func @noFolding // CHECK-NEXT: %[[WORKLOAD0:.+]] = constant 4 : index -// CHECK-NEXT: %0 = flow.dispatch.region[%[[WORKLOAD0]] : index](%arg1 = %arg0 : tensor<4xf32>) -> tensor<4xf32> { +// CHECK-NEXT: %0 = flow.dispatch.region[%[[WORKLOAD0]] : index](%arg1 = %arg0 : tensor<4xf32>) -> (tensor<4xf32>) { // CHECK-NEXT: %1 = mhlo.add %arg1, %arg1 : tensor<4xf32> // CHECK-NEXT: flow.return %1 : tensor<4xf32> // CHECK-NEXT: } @@ -38,7 +38,7 @@ func @elementwiseOps(%arg0 : tensor<4xf32>) -> tensor<4xf32> { // CHECK-LABEL: func @elementwiseOps // CHECK: %[[WORKLOAD0:.+]] = constant 4 -// CHECK: %[[R0:.+]] = flow.dispatch.region[%[[WORKLOAD0]] : index](%arg1 = %arg0 : tensor<4xf32>) -> tensor<4xf32> { +// CHECK: %[[R0:.+]] = flow.dispatch.region[%[[WORKLOAD0]] : index](%arg1 = %arg0 : tensor<4xf32>) -> (tensor<4xf32>) { // CHECK-NEXT: %1 = mhlo.add %arg1, %arg1 : tensor<4xf32> // CHECK-NEXT: %2 = mhlo.subtract %1, %arg1 : tensor<4xf32> // CHECK-NEXT: %3 = mhlo.multiply %arg1, %2 : tensor<4xf32> @@ -69,17 +69,17 @@ func @interleavedDot(%arg0 : tensor<4x4xf32>) -> tensor<4x4xf32> { // CHECK-LABEL: func @interleavedDot // CHECK-NEXT: %[[WORKLOAD0:.+]] = constant 16 : index -// CHECK-NEXT: %[[R0:.+]] = flow.dispatch.region[%[[WORKLOAD0]] : index](%arg1 = %arg0 : tensor<4x4xf32>) -> tensor<4x4xf32> { +// CHECK-NEXT: %[[R0:.+]] = flow.dispatch.region[%[[WORKLOAD0]] : index](%arg1 = %arg0 : tensor<4x4xf32>) -> (tensor<4x4xf32>) { // CHECK-NEXT: %3 = mhlo.add %arg1, %arg1 : tensor<4x4xf32> // CHECK-NEXT: flow.return %3 : tensor<4x4xf32> // CHECK-NEXT: } // CHECK-NEXT: %[[WORKLOAD1:.+]] = constant 16 : index -// CHECK-NEXT: %[[R1:.+]] = flow.dispatch.region[%[[WORKLOAD1]] : index](%arg1 = %[[R0]] : tensor<4x4xf32>, %arg2 = %arg0 : tensor<4x4xf32>) -> tensor<4x4xf32> { +// CHECK-NEXT: %[[R1:.+]] = flow.dispatch.region[%[[WORKLOAD1]] : index](%arg1 = %[[R0]] : tensor<4x4xf32>, %arg2 = %arg0 : tensor<4x4xf32>) -> (tensor<4x4xf32>) { // CHECK-NEXT: %3 = "mhlo.dot"(%arg1, %arg2) : (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> // CHECK-NEXT: flow.return %3 : tensor<4x4xf32> // CHECK-NEXT: } // CHECK-NEXT: %[[WORKLOAD2:.+]] = constant 16 : index -// CHECK-NEXT: %[[R2:.+]] = flow.dispatch.region[%[[WORKLOAD2]] : index](%arg1 = %[[R1]] : tensor<4x4xf32>, %arg2 = %arg0 : tensor<4x4xf32>) -> tensor<4x4xf32> { +// CHECK-NEXT: %[[R2:.+]] = flow.dispatch.region[%[[WORKLOAD2]] : index](%arg1 = %[[R1]] : tensor<4x4xf32>, %arg2 = %arg0 : tensor<4x4xf32>) -> (tensor<4x4xf32>) { // CHECK-NEXT: %3 = mhlo.multiply %arg1, %arg2 : tensor<4x4xf32> // CHECK-NEXT: flow.return %3 : tensor<4x4xf32> // CHECK-NEXT: } diff --git a/iree/compiler/Dialect/Flow/Transforms/test/form_streams.mlir b/iree/compiler/Dialect/Flow/Transforms/test/form_streams.mlir index 8796e2431222..927f1b5b9363 100644 --- a/iree/compiler/Dialect/Flow/Transforms/test/form_streams.mlir +++ b/iree/compiler/Dialect/Flow/Transforms/test/form_streams.mlir @@ -1,22 +1,18 @@ -// RUN: iree-opt -split-input-file -iree-flow-form-streams %s | IreeFileCheck %s +// RUN: iree-opt -split-input-file -iree-flow-form-streams -cse -canonicalize %s | IreeFileCheck %s // CHECK-LABEL: func @outsideTieShape func @outsideTieShape(%arg0: tensor {iree.reflection = {}}, %arg1: !shapex.ranked_shape<[?]> {iree.reflection = {}}) -> (tensor {iree.reflection = {}}) attributes {iree.module.export} { - // CHECK: %[[WORKLOAD0:.+]] = constant 0 : index %c0 = constant 0 : index - // CHECK-NEXT: %0 = shapex.tie_shape %arg0, %arg1 : tensor, !shapex.ranked_shape<[?]> - %2 = shapex.tie_shape %arg0, %arg1 : tensor, !shapex.ranked_shape<[?]> - // CHECK-NEXT: %[[WORKLOAD1:.+]] = constant 1 : index - %c1 = constant 1 : index - // CHECK-NEXT: %1 = shapex.ranked_dim %arg1[0] : !shapex.ranked_shape<[?]> -> index - // CHECK-NEXT: %2 = flow.ex.stream.fragment(%arg2 = %arg0 : tensor, %arg3 = %1 : index, %arg4 = %[[WORKLOAD0]] : index, %arg5 = %0 : tensor) -> tensor { - // CHECK-NEXT: %3 = shapex.make_ranked_shape %arg3 : (index) -> !shapex.ranked_shape<[?]> - // CHECK-NEXT: %4 = shapex.tie_shape %arg2, %3 : tensor, !shapex.ranked_shape<[?]> - // CHECK-NEXT: %5 = flow.dispatch @main_ex_dispatch_1::@main_ex_dispatch_1[%arg4] (%arg4, %4) : (index, tensor) -> tensor - // CHECK-NEXT: flow.return %5 : tensor + // CHECK-DAG: %[[DIM:.+]] = shapex.ranked_dim %arg1[0] + %dim = shapex.ranked_dim %arg1[0] : !shapex.ranked_shape<[?]> -> index + // CHECK-NEXT: %[[RET:.+]] = flow.ex.stream.fragment(%[[DIM]], %arg0) : (index, tensor{%[[DIM]]}) -> (tensor{%[[DIM]]}) = + // CHECK-NEXT: (%[[INNER_DIM:.+]]: index, %[[CAPTURE:.+]]: tensor) -> (tensor) { + // CHECK-NEXT: %[[WORKLOAD0:.+]] = constant 0 : index + // CHECK-NEXT: %[[INNER_RET:.+]] = flow.dispatch @main_ex_dispatch_1::@main_ex_dispatch_1[%[[WORKLOAD0]]](%[[INNER_DIM]], %[[CAPTURE]]) : (index, tensor{%[[INNER_DIM]]}) -> (tensor{%[[INNER_DIM]]}) + // CHECK-NEXT: flow.return %[[INNER_RET]] : tensor // CHECK-NEXT: } - %15 = flow.dispatch @main_ex_dispatch_1::@main_ex_dispatch_1[%c0] (%c0, %2) : (index, tensor) -> tensor - // CHECK-NEXT: return %2 : tensor + %15 = flow.dispatch @main_ex_dispatch_1::@main_ex_dispatch_1[%c0](%dim, %arg0) : (index, tensor{%dim}) -> tensor{%dim} + // CHECK-NEXT: return %[[RET]] : tensor return %15 : tensor } @@ -35,15 +31,16 @@ flow.executable @outerOps_ex_dispatch_0 { } // CHECK-LABEL: func @outerOps func @outerOps(%arg0: tensor<4xf32>) -> tensor<4xf32> { - // CHECK: %[[WORKLOAD0:.+]] = constant 4 : index - %cst = constant 4 : index - // CHECK-NEXT: %0 = addf %arg0, %arg0 : tensor<4xf32> + // CHECK: %0 = addf %arg0, %arg0 : tensor<4xf32> %0 = addf %arg0, %arg0 : tensor<4xf32> - // CHECK-NEXT: %1 = flow.ex.stream.fragment(%arg1 = %[[WORKLOAD0]] : index, %arg2 = %0 : tensor<4xf32>) -> tensor<4xf32> { - // CHECK-NEXT: %3 = flow.dispatch @outerOps_ex_dispatch_0::@outerOps_rgn_dispatch_0[%arg1] (%arg2) : (tensor<4xf32>) -> tensor<4xf32> + %cst = constant 4 : index + // CHECK-NEXT: %1 = flow.ex.stream.fragment(%0) : (tensor<4xf32>) -> (tensor<4xf32>) = + // CHECK-NEXT: (%[[INNER_ARG:.+]]: tensor<4xf32>) -> (tensor<4xf32>) { + // CHECK-NEXT: %[[WORKLOAD:.+]] = constant 4 : index + // CHECK-NEXT: %3 = flow.dispatch @outerOps_ex_dispatch_0::@outerOps_rgn_dispatch_0[%[[WORKLOAD]]](%[[INNER_ARG]]) : (tensor<4xf32>) -> (tensor<4xf32>) // CHECK-NEXT: flow.return %3 : tensor<4xf32> // CHECK-NEXT: } - %1 = flow.dispatch @outerOps_ex_dispatch_0::@outerOps_rgn_dispatch_0[%cst] (%0) : (tensor<4xf32>) -> tensor<4xf32> + %1 = flow.dispatch @outerOps_ex_dispatch_0::@outerOps_rgn_dispatch_0[%cst](%0) : (tensor<4xf32>) -> tensor<4xf32> // CHECK: %2 = addf %1, %1 : tensor<4xf32> %2 = addf %1, %1 : tensor<4xf32> // CHECK-NEXT: return %2 : tensor<4xf32> @@ -54,15 +51,16 @@ func @outerOps(%arg0: tensor<4xf32>) -> tensor<4xf32> { // CHECK-LABEL: func @nondependentOuterOps( func @nondependentOuterOps(%arg0: tensor<4xf32>) -> tensor<4xf32> { - // CHECK-NEXT: %[[WORKLOAD:.+]] = constant 4 : index %cst = constant 4 : index // CHECK-NEXT: %[[ADD1:.+]] = addf %arg0, %arg0 : tensor<4xf32> %add1 = addf %arg0, %arg0 : tensor<4xf32> - // CHECK-NEXT: %[[S:.+]] = flow.ex.stream.fragment(%arg1 = %[[WORKLOAD]] : index, %arg2 = %arg0 : tensor<4xf32>, %arg3 = %[[ADD1]] : tensor<4xf32>) -> tensor<4xf32> { - // CHECK-NEXT: %[[D1:.+]] = flow.dispatch @dispatch_1::@dispatch_1[%arg1] (%arg2, %arg2) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> - %d1 = flow.dispatch @dispatch_1::@dispatch_1[%cst] (%arg0, %arg0) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> - // CHECK-NEXT: %[[D2:.+]] = flow.dispatch @dispatch_2::@dispatch_2[%arg1] (%[[D1]], %arg3) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> - %d2 = flow.dispatch @dispatch_2::@dispatch_2[%cst] (%d1, %add1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + // CHECK-NEXT: %[[S:.+]] = flow.ex.stream.fragment(%arg0, %[[ADD1]]) : (tensor<4xf32>, tensor<4xf32>) -> (tensor<4xf32>) = + // CHECK-NEXT: (%arg1: tensor<4xf32>, %arg2: tensor<4xf32>) -> (tensor<4xf32>) { + // CHECK-NEXT: %[[WORKLOAD:.+]] = constant 4 : index + // CHECK-NEXT: %[[D1:.+]] = flow.dispatch @dispatch_1::@dispatch_1[%[[WORKLOAD]]](%arg1, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> (tensor<4xf32>) + %d1 = flow.dispatch @dispatch_1::@dispatch_1[%cst](%arg0, %arg0) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + // CHECK-NEXT: %[[D2:.+]] = flow.dispatch @dispatch_2::@dispatch_2[%[[WORKLOAD]]](%[[D1]], %arg2) : (tensor<4xf32>, tensor<4xf32>) -> (tensor<4xf32>) + %d2 = flow.dispatch @dispatch_2::@dispatch_2[%cst](%d1, %add1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> // CHECK-NEXT: flow.return %[[D2]] : tensor<4xf32> // CHECK-NEXT: } // CHECK-NEXT: %[[ADD2:.+]] = addf %[[S]], %arg0 : tensor<4xf32> @@ -86,20 +84,23 @@ flow.executable @interleavedOuterOps_ex_dispatch_0 { } // CHECK-LABEL: func @interleavedOuterOps( func @interleavedOuterOps(%arg0: tensor<4xf32>) -> tensor<4xf32> { - // CHECK-NEXT: %[[WORKLOAD:.+]] = constant 4 : index %cst = constant 4 : index - // CHECK-NEXT: %0 = flow.ex.stream.fragment(%arg1 = %[[WORKLOAD]] : index, %arg2 = %arg0 : tensor<4xf32>) -> tensor<4xf32> { - // CHECK-NEXT: %3 = flow.dispatch @interleavedOuterOps_ex_dispatch_0::@interleavedOuterOps_rgn_dispatch_0[%arg1] (%arg2) : (tensor<4xf32>) -> tensor<4xf32> + // CHECK-NEXT: %0 = flow.ex.stream.fragment(%arg0) : (tensor<4xf32>) -> (tensor<4xf32>) = + // CHECK-NEXT: (%arg1: tensor<4xf32>) -> (tensor<4xf32>) + // CHECK-NEXT: %[[WORKLOAD1:.+]] = constant 4 : index + // CHECK-NEXT: %3 = flow.dispatch @interleavedOuterOps_ex_dispatch_0::@interleavedOuterOps_rgn_dispatch_0[%[[WORKLOAD1]]](%arg1) : (tensor<4xf32>) -> (tensor<4xf32>) // CHECK-NEXT: flow.return %3 : tensor<4xf32> // CHECK-NEXT: } - %0 = flow.dispatch @interleavedOuterOps_ex_dispatch_0::@interleavedOuterOps_rgn_dispatch_0[%cst] (%arg0) : (tensor<4xf32>) -> tensor<4xf32> + %0 = flow.dispatch @interleavedOuterOps_ex_dispatch_0::@interleavedOuterOps_rgn_dispatch_0[%cst](%arg0) : (tensor<4xf32>) -> tensor<4xf32> // CHECK-NEXT: %1 = addf %0, %0 : tensor<4xf32> %1 = addf %0, %0 : tensor<4xf32> - // CHECK-NEXT: %2 = flow.ex.stream.fragment(%arg1 = %[[WORKLOAD]] : index, %arg2 = %1 : tensor<4xf32>) -> tensor<4xf32> { - // CHECK-NEXT: %3 = flow.dispatch @interleavedOuterOps_ex_dispatch_0::@interleavedOuterOps_rgn_dispatch_0[%arg1] (%arg2) : (tensor<4xf32>) -> tensor<4xf32> + // CHECK-NEXT: %2 = flow.ex.stream.fragment(%1) : (tensor<4xf32>) -> (tensor<4xf32>) = + // CHECK-NEXT: (%arg1: tensor<4xf32>) -> (tensor<4xf32>) { + // CHECK-NEXT: %[[WORKLOAD2:.+]] = constant 4 : index + // CHECK-NEXT: %3 = flow.dispatch @interleavedOuterOps_ex_dispatch_0::@interleavedOuterOps_rgn_dispatch_0[%[[WORKLOAD2]]](%arg1) : (tensor<4xf32>) -> (tensor<4xf32>) // CHECK-NEXT: flow.return %3 : tensor<4xf32> // CHECK-NEXT: } - %2 = flow.dispatch @interleavedOuterOps_ex_dispatch_0::@interleavedOuterOps_rgn_dispatch_0[%cst] (%1) : (tensor<4xf32>) -> tensor<4xf32> + %2 = flow.dispatch @interleavedOuterOps_ex_dispatch_0::@interleavedOuterOps_rgn_dispatch_0[%cst](%1) : (tensor<4xf32>) -> tensor<4xf32> // CHECK-NEXT: return %2 : tensor<4xf32> return %2 : tensor<4xf32> } @@ -116,13 +117,14 @@ flow.executable @independentOps_ex_dispatch_0 { } // CHECK-LABEL: func @independentOps( func @independentOps(%arg0: tensor<4xf32>) -> (tensor<4xf32>, tensor<4xf32>) { - // CHECK-NEXT: %[[WORKLOAD:.+]] = constant 4 : index %cst = constant 4 : index - // CHECK-NEXT: %0:2 = flow.ex.stream.fragment(%arg1 = %[[WORKLOAD]] : index, %arg2 = %arg0 : tensor<4xf32>) -> (tensor<4xf32>, tensor<4xf32>) { - // CHECK-DAG: = flow.dispatch @independentOps_ex_dispatch_0::@independentOps_rgn_dispatch_0[%arg1] (%arg2) - %0 = flow.dispatch @independentOps_ex_dispatch_0::@independentOps_rgn_dispatch_0[%cst] (%arg0) : (tensor<4xf32>) -> tensor<4xf32> - // CHECK-DAG: = flow.dispatch @independentOps_ex_dispatch_0::@independentOps_rgn_dispatch_1[%arg1] (%arg2) - %1 = flow.dispatch @independentOps_ex_dispatch_0::@independentOps_rgn_dispatch_1[%cst] (%arg0) : (tensor<4xf32>) -> tensor<4xf32> + // CHECK-NEXT: %0:2 = flow.ex.stream.fragment(%arg0) : (tensor<4xf32>) -> (tensor<4xf32>, tensor<4xf32>) = + // CHECK-NEXT: (%arg1: tensor<4xf32>) -> (tensor<4xf32>, tensor<4xf32>) { + // CHECK-NEXT: %[[WORKLOAD:.+]] = constant 4 : index + // CHECK-DAG: = flow.dispatch @independentOps_ex_dispatch_0::@independentOps_rgn_dispatch_0[%[[WORKLOAD]]](%arg1) + %0 = flow.dispatch @independentOps_ex_dispatch_0::@independentOps_rgn_dispatch_0[%cst](%arg0) : (tensor<4xf32>) -> tensor<4xf32> + // CHECK-DAG: = flow.dispatch @independentOps_ex_dispatch_0::@independentOps_rgn_dispatch_1[%[[WORKLOAD]]](%arg1) + %1 = flow.dispatch @independentOps_ex_dispatch_0::@independentOps_rgn_dispatch_1[%cst](%arg0) : (tensor<4xf32>) -> tensor<4xf32> // CHECK-NEXT: flow.return %{{.+}}, %{{.+}} // CHECK-NEXT: } // CHECK-NEXT: return %{{.+}}, %{{.+}} @@ -166,17 +168,18 @@ flow.executable @interleavedDot_ex_dispatch_2 { } // CHECK-LABEL: func @interleavedDot( func @interleavedDot(%arg0: tensor<4x4xf32>) -> tensor<4x4xf32> { - // CHECK-NEXT: %[[WORKLOAD:.+]] = constant 16 : index %cst = constant 16 : index - // CHECK-NEXT: %0 = flow.ex.stream.fragment(%arg1 = %[[WORKLOAD]] : index, %arg2 = %arg0 : tensor<4x4xf32>) -> tensor<4x4xf32> { - // CHECK-NEXT: %1 = flow.dispatch @interleavedDot_ex_dispatch_0::@interleavedDot_rgn_dispatch_0[%arg1] (%arg2) : (tensor<4x4xf32>) -> tensor<4x4xf32> - // CHECK-NEXT: %2 = flow.dispatch @interleavedDot_ex_dispatch_1::@interleavedDot_rgn_dispatch_1[%arg1] (%1, %arg2) : (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> - // CHECK-NEXT: %3 = flow.dispatch @interleavedDot_ex_dispatch_2::@interleavedDot_rgn_dispatch_2[%arg1] (%2, %arg2) : (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> + // CHECK-NEXT: %0 = flow.ex.stream.fragment(%arg0) : (tensor<4x4xf32>) -> (tensor<4x4xf32>) = + // CHECK-NEXT: (%arg1: tensor<4x4xf32>) -> (tensor<4x4xf32>) { + // CHECK-NEXT: %[[WORKLOAD:.+]] = constant 16 : index + // CHECK-NEXT: %1 = flow.dispatch @interleavedDot_ex_dispatch_0::@interleavedDot_rgn_dispatch_0[%[[WORKLOAD]]](%arg1) : (tensor<4x4xf32>) -> (tensor<4x4xf32>) + // CHECK-NEXT: %2 = flow.dispatch @interleavedDot_ex_dispatch_1::@interleavedDot_rgn_dispatch_1[%[[WORKLOAD]]](%1, %arg1) : (tensor<4x4xf32>, tensor<4x4xf32>) -> (tensor<4x4xf32>) + // CHECK-NEXT: %3 = flow.dispatch @interleavedDot_ex_dispatch_2::@interleavedDot_rgn_dispatch_2[%[[WORKLOAD]]](%2, %arg1) : (tensor<4x4xf32>, tensor<4x4xf32>) -> (tensor<4x4xf32>) // CHECK-NEXT: flow.return %3 : tensor<4x4xf32> // CHECK-NEXT: } - %0 = flow.dispatch @interleavedDot_ex_dispatch_0::@interleavedDot_rgn_dispatch_0[%cst] (%arg0) : (tensor<4x4xf32>) -> tensor<4x4xf32> - %1 = flow.dispatch @interleavedDot_ex_dispatch_1::@interleavedDot_rgn_dispatch_1[%cst] (%0, %arg0) : (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> - %2 = flow.dispatch @interleavedDot_ex_dispatch_2::@interleavedDot_rgn_dispatch_2[%cst] (%1, %arg0) : (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> + %0 = flow.dispatch @interleavedDot_ex_dispatch_0::@interleavedDot_rgn_dispatch_0[%cst](%arg0) : (tensor<4x4xf32>) -> tensor<4x4xf32> + %1 = flow.dispatch @interleavedDot_ex_dispatch_1::@interleavedDot_rgn_dispatch_1[%cst](%0, %arg0) : (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> + %2 = flow.dispatch @interleavedDot_ex_dispatch_2::@interleavedDot_rgn_dispatch_2[%cst](%1, %arg0) : (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> // CHECK-NEXT: return %0 : tensor<4x4xf32> return %2 : tensor<4x4xf32> } @@ -207,20 +210,23 @@ flow.executable @caller_ex_dispatch_1 { } // CHECK-LABEL: func @caller( func @caller(%arg0: tensor<4xf32>) -> tensor<4xf32> { - // CHECK-NEXT: %[[WORKLOAD:.+]] = constant 4 : index %cst = constant 4 : index - // CHECK-NEXT: %0 = flow.ex.stream.fragment(%arg1 = %[[WORKLOAD]] : index, %arg2 = %arg0 : tensor<4xf32>) -> tensor<4xf32> { - // CHECK-NEXT: %3 = flow.dispatch @caller_ex_dispatch_0::@caller_rgn_dispatch_0[%arg1] (%arg2) : (tensor<4xf32>) -> tensor<4xf32> + // CHECK-NEXT: %0 = flow.ex.stream.fragment(%arg0) : (tensor<4xf32>) -> (tensor<4xf32>) = + // CHECK-NEXT: (%arg1: tensor<4xf32>) -> (tensor<4xf32>) { + // CHECK-NEXT: %[[WORKLOAD1:.+]] = constant 4 : index + // CHECK-NEXT: %3 = flow.dispatch @caller_ex_dispatch_0::@caller_rgn_dispatch_0[%[[WORKLOAD1]]](%arg1) : (tensor<4xf32>) -> (tensor<4xf32>) // CHECK-NEXT: flow.return %3 : tensor<4xf32> // CHECK-NEXT: } - %0 = flow.dispatch @caller_ex_dispatch_0::@caller_rgn_dispatch_0[%cst] (%arg0) : (tensor<4xf32>) -> tensor<4xf32> + %0 = flow.dispatch @caller_ex_dispatch_0::@caller_rgn_dispatch_0[%cst](%arg0) : (tensor<4xf32>) -> tensor<4xf32> // CHECK-NEXT: %1 = call @callee(%0) : (tensor<4xf32>) -> tensor<4xf32> %1 = call @callee(%0) : (tensor<4xf32>) -> tensor<4xf32> - // CHECK-NEXT: %2 = flow.ex.stream.fragment(%arg1 = %[[WORKLOAD]] : index, %arg2 = %arg0 : tensor<4xf32>, %arg3 = %1 : tensor<4xf32>) -> tensor<4xf32> { - // CHECK-NEXT: %3 = flow.dispatch @caller_ex_dispatch_1::@caller_rgn_dispatch_1[%arg1] (%arg2, %arg3) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + // CHECK-NEXT: %2 = flow.ex.stream.fragment(%arg0, %1) : (tensor<4xf32>, tensor<4xf32>) -> (tensor<4xf32>) = + // CHECK-NEXT: (%arg1: tensor<4xf32>, %arg2: tensor<4xf32>) -> (tensor<4xf32>) { + // CHECK-NEXT: %[[WORKLOAD2:.+]] = constant 4 : index + // CHECK-NEXT: %3 = flow.dispatch @caller_ex_dispatch_1::@caller_rgn_dispatch_1[%[[WORKLOAD2]]](%arg1, %arg2) : (tensor<4xf32>, tensor<4xf32>) -> (tensor<4xf32>) // CHECK-NEXT: flow.return %3 : tensor<4xf32> // CHECK-NEXT: } - %2 = flow.dispatch @caller_ex_dispatch_1::@caller_rgn_dispatch_1[%cst] (%arg0, %1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + %2 = flow.dispatch @caller_ex_dispatch_1::@caller_rgn_dispatch_1[%cst](%arg0, %1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> // CHECK-NEXT: return %2 : tensor<4xf32> return %2 : tensor<4xf32> } @@ -235,43 +241,37 @@ flow.executable @callee_ex_dispatch_0 { } // CHECK-LABEL: func @callee( func @callee(%arg0: tensor<4xf32>) -> tensor<4xf32> { - // CHECK-NEXT: %[[WORKLOAD:.+]] = constant 4 : index %cst = constant 4 : index - // CHECK-NEXT: %0 = flow.ex.stream.fragment(%arg1 = %[[WORKLOAD]] : index, %arg2 = %arg0 : tensor<4xf32>) -> tensor<4xf32> { - // CHECK-NEXT: %1 = flow.dispatch @callee_ex_dispatch_0::@callee_rgn_dispatch_0[%arg1] (%arg2) : (tensor<4xf32>) -> tensor<4xf32> + // CHECK-NEXT: %0 = flow.ex.stream.fragment(%arg0) : (tensor<4xf32>) -> (tensor<4xf32>) = + // CHECK-NEXT: (%arg1: tensor<4xf32>) -> (tensor<4xf32>) { + // CHECK-NEXT: %[[WORKLOAD:.+]] = constant 4 : index + // CHECK-NEXT: %1 = flow.dispatch @callee_ex_dispatch_0::@callee_rgn_dispatch_0[%[[WORKLOAD]]](%arg1) : (tensor<4xf32>) -> (tensor<4xf32>) // CHECK-NEXT: flow.return %1 : tensor<4xf32> // CHECK-NEXT: } - %0 = flow.dispatch @callee_ex_dispatch_0::@callee_rgn_dispatch_0[%cst] (%arg0) : (tensor<4xf32>) -> tensor<4xf32> + %0 = flow.dispatch @callee_ex_dispatch_0::@callee_rgn_dispatch_0[%cst](%arg0) : (tensor<4xf32>) -> tensor<4xf32> // CHECK-NEXT: return %0 : tensor<4xf32> return %0 : tensor<4xf32> } // CHECK-LABEL: @simple_unary -// CHECK-SAME: %[[A0:[^:[:space:]]+]]: tensor -// CHECK-SAME: %[[A1:[^:[:space:]]+]]: !shapex.ranked_shape<[?,?]> +// CHECK-SAME: %[[A0:.+]]: tensor +// CHECK-SAME: %[[A1:.+]]: !shapex.ranked_shape<[?,?]> func @simple_unary(%arg0: tensor, %arg1: !shapex.ranked_shape<[?,?]>) -> (tensor, !shapex.ranked_shape<[?,?]>) { - %0 = shapex.ranked_dim %arg1[0] : !shapex.ranked_shape<[?,?]> -> index - %1 = shapex.ranked_dim %arg1[1] : !shapex.ranked_shape<[?,?]> -> index - %2 = muli %0, %1 : index - %4 = shapex.ranked_dim %arg1[0] : !shapex.ranked_shape<[?,?]> -> index - %5 = shapex.ranked_dim %arg1[1] : !shapex.ranked_shape<[?,?]> -> index - %3 = shapex.tie_shape %arg0, %arg1 : tensor, !shapex.ranked_shape<[?,?]> + // CHECK-DAG: %[[DIM0:.+]] = shapex.ranked_dim %arg1[0] + %dim0 = shapex.ranked_dim %arg1[0] : !shapex.ranked_shape<[?,?]> -> index + // CHECK-DAG: %[[DIM1:.+]] = shapex.ranked_dim %arg1[1] + %dim1 = shapex.ranked_dim %arg1[1] : !shapex.ranked_shape<[?,?]> -> index + // CHECK: %[[SZ:.+]] = muli + %2 = muli %dim0, %dim1 : index // Verify that the fragment captures the tie_shapes and marshals the indices // in as loose index values (not as ranked_shape types). - // CHECK: %[[S:.+]] = flow.ex.stream.fragment - // CHECK-SAME: %[[STREAM_A0:[^:[:space:]]+]] = %[[A0]] : tensor, - // CHECK-SAME: %[[STREAM_A1:[^:[:space:]]+]] = %[[UNUSED0:[^:[:space:]]+]] : index, - // CHECK-SAME: %[[STREAM_A2:[^:[:space:]]+]] = %[[UNUSED1:[^:[:space:]]+]] : index, - // CHECK-SAME: { - // CHECK: %[[STREAM_RS0:.+]] = shapex.make_ranked_shape %[[STREAM_A1]], %[[STREAM_A2]] - // CHECK: %[[STREAM_R0:.+]] = shapex.tie_shape %[[STREAM_A0]], %[[STREAM_RS0]] - // CHECK: %[[STREAM_R1:.+]] = flow.dispatch @simple_unary_ex_dispatch_0 - // CHECK: %[[STREAM_R2:.+]] = shapex.tie_shape %[[STREAM_R1]], %[[STREAM_RS0]] - // CHECK: return %[[STREAM_R2]] + // CHECK: %[[S:.+]] = flow.ex.stream.fragment(%[[SZ]], %[[A0]], %[[DIM0]], %[[DIM1]]) : (index, tensor{%[[DIM0]], %[[DIM1]]}, index, index) -> (tensor{%[[DIM0]], %[[DIM1]]}) = + // CHECK: (%arg2: index, %arg3: tensor, %arg4: index, %arg5: index) -> (tensor) { + // CHECK: %[[STREAM_RET:.+]] = flow.dispatch @simple_unary_ex_dispatch_0{{.+}}[%arg2](%arg3, %arg4, %arg5) : (tensor{%arg4, %arg5}, index, index) -> (tensor{%arg4, %arg5}) + // CHECK: return %[[STREAM_RET]] // CHECK: } - %6 = flow.dispatch @simple_unary_ex_dispatch_0::@simple_unary_ex_dispatch_0[%2] (%3, %4, %5) : (tensor, index, index) -> tensor - %7 = shapex.tie_shape %6, %arg1 : tensor, !shapex.ranked_shape<[?,?]> - return %7, %arg1 : tensor, !shapex.ranked_shape<[?,?]> + %3 = flow.dispatch @simple_unary_ex_dispatch_0::@simple_unary_ex_dispatch_0[%2](%arg0, %dim0, %dim1) : (tensor{%dim0, %dim1}, index, index) -> tensor{%dim0, %dim1} + return %3, %arg1 : tensor, !shapex.ranked_shape<[?,?]> } @@ -279,18 +279,16 @@ func @simple_unary(%arg0: tensor, %arg1: !shapex.ranked_shape<[?,?]>) - // CHECK-LABEL: @bad_input_ordering func @bad_input_ordering() -> (tensor, tensor) { - // CHECK: %[[W:.+]] = constant 1 : index - %workload = constant 1 : index // CHECK: %[[S:.+]] = flow.ex.stream.fragment - // CHECK-SAME: { - // CHECK-DAG: %[[D1:.+]] = flow.dispatch @dispatch_1::@dispatch_1 - // CHECK: flow.return - // CHECK-NEXT: } - %0 = flow.dispatch @dispatch_1::@dispatch_1[%workload] () : () -> tensor - // CHECK: %[[C2:.+]] = constant 2 : i32 - %c2 = constant 2 : i32 + // CHECK: = constant 1 : index + // CHECK: %[[D1:.+]] = flow.dispatch @dispatch_1::@dispatch_1 + %workload = constant 1 : index + %0 = flow.dispatch @dispatch_1::@dispatch_1[%workload]() : () -> tensor + // CHECK: %[[C2:.+]] = constant 2 : i32 // CHECK-DAG: %[[D2:.+]] = flow.dispatch @dispatch_2::@dispatch_2 - %1 = flow.dispatch @dispatch_2::@dispatch_2[%workload] (%c2) : (i32) -> tensor + %c2 = constant 2 : i32 + %1 = flow.dispatch @dispatch_2::@dispatch_2[%workload](%c2) : (i32) -> tensor + // CHECK: flow.return return %0, %1 : tensor, tensor } @@ -298,24 +296,20 @@ func @bad_input_ordering() -> (tensor, tensor) { // CHECK-LABEL: @interstream_readback func @interstream_readback() -> (tensor, tensor, tensor<2xf32>) { - // CHECK: %[[W:.+]] = constant 1 : index %w = constant 1 : index // CHECK: %[[S1:.+]]:2 = flow.ex.stream.fragment - // CHECK-SAME: { - // CHECK: %[[D1:.+]] = flow.dispatch @dispatch_1::@dispatch_1 - // CHECK-DAG: %[[D2:.+]] = flow.dispatch @dispatch_2::@dispatch_2 - // Could be returned in either order - // CHECK-NEXT: flow.return - // CHECK-NEXT: } - %d1 = flow.dispatch @dispatch_1::@dispatch_1[%w] () : () -> tensor - %d2 = flow.dispatch @dispatch_2::@dispatch_2[%w] () : () -> tensor + // CHECK: %[[W:.+]] = constant 1 : index + // CHECK: %[[D1:.+]] = flow.dispatch @dispatch_1::@dispatch_1 + // CHECK-DAG: %[[D2:.+]] = flow.dispatch @dispatch_2::@dispatch_2 + // Could be returned in either order + // CHECK-NEXT: flow.return + %d1 = flow.dispatch @dispatch_1::@dispatch_1[%w]() : () -> tensor + %d2 = flow.dispatch @dispatch_2::@dispatch_2[%w]() : () -> tensor // CHECK: %[[READBACK:.+]] = flow.tensor.load %[[S1]] %readback = flow.tensor.load %d1 : tensor // CHECK: %[[S2:.+]] = flow.ex.stream.fragment - // CHECK-SAME: { // CHECK-DAG: %[[D3:.+]] = flow.dispatch @dispatch_3::@dispatch_3 // CHECK: flow.return %[[D3]] - // CHECK-NEXT: } %d3 = flow.dispatch @dispatch_3::@dispatch_3[%w] (%readback) : (i32) -> tensor<2xf32> // CHECK: return %[[S1]]# // CHECK-SAME: %[[S1]]# @@ -325,31 +319,29 @@ func @interstream_readback() -> (tensor, tensor, tensor<2xf32>) { } // ----- + // CHECK-LABEL: @ordering func @ordering(%w : index) -> (tensor, tensor, tensor) { - // CHECK: %[[C1:.+]] = constant 1 %c1 = constant 1 : i32 // CHECK: %[[S1:.+]] = flow.ex.stream.fragment - // CHECK-SAME: { - // CHECK-DAG: %[[D1:.+]] = flow.dispatch @dispatch_1::@dispatch_1 - // CHECK-NEXT: flow.return %[[D1]] - // CHECK-NEXT: } - %d1 = flow.dispatch @dispatch_1::@dispatch_1[%w] (%c1) : (i32) -> (tensor) + // CHECK: %[[C1:.+]] = constant 1 + // CHECK-DAG: %[[D1:.+]] = flow.dispatch @dispatch_1::@dispatch_1 + // CHECK-NEXT: flow.return %[[D1]] + %d1 = flow.dispatch @dispatch_1::@dispatch_1[%w](%c1) : (i32) -> (tensor) // CHECK: %[[SE_USER:.+]] = iree.do_not_optimize(%[[S1]]) %side_effecting_user = iree.do_not_optimize(%d1) : tensor - // CHECK: %[[C2:.+]] = constant 2 %c2 = constant 2 : i32 // CHECK: %[[S2:.+]] = flow.ex.stream.fragment - // CHECK-SAME: { - // CHECK-DAG: %[[D2:.+]] = flow.dispatch @dispatch_2::@dispatch_2 - // CHECK-NEXT: flow.return %[[D2]] - // CHECK-NEXT: } - %d2 = flow.dispatch @dispatch_2::@dispatch_2[%w] (%c2) : (i32) -> (tensor) + // CHECK: %[[C2:.+]] = constant 2 + // CHECK-DAG: %[[D2:.+]] = flow.dispatch @dispatch_2::@dispatch_2 + // CHECK-NEXT: flow.return %[[D2]] + %d2 = flow.dispatch @dispatch_2::@dispatch_2[%w](%c2) : (i32) -> (tensor) // CHECK: return %[[S1]], %[[S2]], %[[SE_USER]] return %d1, %d2, %side_effecting_user : tensor, tensor, tensor } // ----- + // CHECK-LABEL: @metadata_only func @metadata_only(%t: tensor) -> (tensor, !shapex.ranked_shape<[?]>) { // CHECK-NOT: flow.ex.stream.fragment diff --git a/iree/compiler/Dialect/Flow/Transforms/test/hoist_unstreamable_ops.mlir b/iree/compiler/Dialect/Flow/Transforms/test/hoist_unstreamable_ops.mlir index 1872945f6633..d94b95aa290e 100644 --- a/iree/compiler/Dialect/Flow/Transforms/test/hoist_unstreamable_ops.mlir +++ b/iree/compiler/Dialect/Flow/Transforms/test/hoist_unstreamable_ops.mlir @@ -8,24 +8,24 @@ func @constants() { // CHECK-DAG: constant 4 : index // CHECK-DAG: constant 5 : index // CHECK-DAG: constant 6 : index - // CHECK: flow.dispatch @dispatch0::@dispatch0[%[[W]]] () : () -> tensor - // CHECK: flow.dispatch @dispatch1::@dispatch1[%[[W]]] () : () -> tensor - // CHECK: flow.dispatch @dispatch2::@dispatch2[%[[W]]] () : () -> tensor - // CHECK: flow.dispatch @dispatch3::@dispatch3[%[[W]]] () : () -> tensor - // CHECK: flow.dispatch @dispatch4::@dispatch4[%[[W]]] () : () -> tensor - // CHECK: flow.dispatch @dispatch5::@dispatch5[%[[W]]] () : () -> tensor + // CHECK: flow.dispatch @dispatch0::@dispatch0[%[[W]]]() : () -> (tensor) + // CHECK: flow.dispatch @dispatch1::@dispatch1[%[[W]]]() : () -> (tensor) + // CHECK: flow.dispatch @dispatch2::@dispatch2[%[[W]]]() : () -> (tensor) + // CHECK: flow.dispatch @dispatch3::@dispatch3[%[[W]]]() : () -> (tensor) + // CHECK: flow.dispatch @dispatch4::@dispatch4[%[[W]]]() : () -> (tensor) + // CHECK: flow.dispatch @dispatch5::@dispatch5[%[[W]]]() : () -> (tensor) %w = constant 1 : index - %d0 = flow.dispatch @dispatch0::@dispatch0[%w] () : () -> (tensor) + %d0 = flow.dispatch @dispatch0::@dispatch0[%w]() : () -> tensor %c2 = constant 2 : index - %d1 = flow.dispatch @dispatch1::@dispatch1[%w] () : () -> (tensor) + %d1 = flow.dispatch @dispatch1::@dispatch1[%w]() : () -> tensor %c3 = constant 3 : index - %d2 = flow.dispatch @dispatch2::@dispatch2[%w] () : () -> (tensor) + %d2 = flow.dispatch @dispatch2::@dispatch2[%w]() : () -> tensor %c4 = constant 4 : index - %d3 = flow.dispatch @dispatch3::@dispatch3[%w] () : () -> (tensor) + %d3 = flow.dispatch @dispatch3::@dispatch3[%w]() : () -> tensor %c5 = constant 5 : index - %d4 = flow.dispatch @dispatch4::@dispatch4[%w] () : () -> (tensor) + %d4 = flow.dispatch @dispatch4::@dispatch4[%w]() : () -> tensor %c6 = constant 6 : index - %d5 = flow.dispatch @dispatch5::@dispatch5[%w] () : () -> (tensor) + %d5 = flow.dispatch @dispatch5::@dispatch5[%w]() : () -> tensor return } @@ -38,16 +38,12 @@ func @dynamic_tensor(%input: tensor, %shape: !shapex.ranked_shape<[?,?] // CHECK-DAG: %[[W:.+]] = constant 1 // CHECK-DAG: %[[DIM0:.+]] shapex.ranked_dim %[[SHAPE]][0] // CHECK-DAG: %[[DIM1:.+]] shapex.ranked_dim %[[SHAPE]][1] - // CHECK: %[[TIE0:.+]] = shapex.tie_shape %[[INPUT]], %[[SHAPE]] // CHECK: %[[D:.+]] = flow.dispatch - // CHECK: %[[TIE1:.+]] = shapex.tie_shape %[[D]], %[[SHAPE]] %w = constant 1 : index - %tie0 = shapex.tie_shape %input, %shape : tensor, !shapex.ranked_shape<[?,?]> %dim0 = shapex.ranked_dim %shape[0] : !shapex.ranked_shape<[?,?]> -> index %dim1 = shapex.ranked_dim %shape[1] : !shapex.ranked_shape<[?,?]> -> index - %d = flow.dispatch @dispatch::@dispatch[%w] (%tie0, %dim0, %dim1) : (tensor, index, index) -> tensor - %tie1 = shapex.tie_shape %d, %shape : tensor, !shapex.ranked_shape<[?,?]> - return %tie1, %shape : tensor, !shapex.ranked_shape<[?,?]> + %d = flow.dispatch @dispatch::@dispatch[%w](%input, %dim0, %dim1) : (tensor{%dim0, %dim1}, index, index) -> tensor{%dim0, %dim1} + return %d, %shape : tensor, !shapex.ranked_shape<[?,?]> } // ----- @@ -87,7 +83,7 @@ func @dependencies_with_dispatch() { %c2 = constant 2 : index %ct3 = constant dense<3> : tensor // CHECK: flow.dispatch - %d0 = flow.dispatch @dispatch0::@dispatch0[%w] () : () -> (tensor) + %d0 = flow.dispatch @dispatch0::@dispatch0[%w]() : () -> tensor // CHECK: addi %add0 = addi %d0, %ct3 : tensor return diff --git a/iree/compiler/Dialect/Flow/Transforms/test/identify_dispatch_regions.mlir b/iree/compiler/Dialect/Flow/Transforms/test/identify_dispatch_regions.mlir index 933f3b181ae2..a54bb7fd7674 100644 --- a/iree/compiler/Dialect/Flow/Transforms/test/identify_dispatch_regions.mlir +++ b/iree/compiler/Dialect/Flow/Transforms/test/identify_dispatch_regions.mlir @@ -13,7 +13,7 @@ func @simpleMath(%arg0 : tensor<4xf32>) -> tensor<4xf32> { // CHECK-NEXT: %[[WORKLOAD:.+]] = constant 4 // CHECK-NEXT: %[[R1:.+]] = flow.dispatch.region // CHECK-SAME: [%[[WORKLOAD]] : index] - // CHECK-SAME: (%arg1 = %arg0 : tensor<4xf32>) -> tensor<4xf32> { + // CHECK-SAME: (%arg1 = %arg0 : tensor<4xf32>) -> (tensor<4xf32>) { // CHECK-NEXT: %1 = mhlo.add %arg1, %arg1 : tensor<4xf32> %0 = mhlo.add %arg0, %arg0 : tensor<4xf32> // CHECK-NEXT: flow.return %1 : tensor<4xf32> @@ -29,7 +29,7 @@ func @stdElementwiseOps(%arg0 : tensor<4xf32>) -> tensor<4xf32> { // CHECK-NEXT: %[[WORKLOAD:.+]] = constant 4 // CHECK-NEXT: %[[R1:.+]] = flow.dispatch.region // CHECK-SAME: [%[[WORKLOAD]] : index] - // CHECK-SAME: (%arg1 = %arg0 : tensor<4xf32>) -> tensor<4xf32> { + // CHECK-SAME: (%arg1 = %arg0 : tensor<4xf32>) -> (tensor<4xf32>) { // CHECK-NEXT: %1 = addf %arg1, %arg1 : tensor<4xf32> %0 = addf %arg0, %arg0 : tensor<4xf32> // CHECK-NEXT: %2 = subf %1, %arg1 : tensor<4xf32> @@ -49,7 +49,7 @@ func @hloElementwiseOps(%arg0 : tensor<4xf32>) -> tensor<4xf32> { // CHECK-NEXT: %[[WORKLOAD:.+]] = constant 4 // CHECK-NEXT: %0 = flow.dispatch.region // CHECK-SAME: [%[[WORKLOAD]] : index] - // CHECK-SAME: (%arg1 = %arg0 : tensor<4xf32>) -> tensor<4xf32> { + // CHECK-SAME: (%arg1 = %arg0 : tensor<4xf32>) -> (tensor<4xf32>) { // CHECK-NEXT: %1 = mhlo.add %arg1, %arg1 : tensor<4xf32> %0 = mhlo.add %arg0, %arg0 : tensor<4xf32> // CHECK-NEXT: %2 = mhlo.subtract %1, %arg1 : tensor<4xf32> @@ -73,21 +73,21 @@ func @interleavedDot(%arg0 : tensor<4x4xf32>) -> tensor<4x4xf32> { // CHECK: %[[WORKLOAD2:.+]] = constant 16 : index // CHECK: %[[R0:.+]] = flow.dispatch.region // CHECK-SAME: [%[[WORKLOAD0]] : index] - // CHECK-SAME: (%arg1 = %arg0 : tensor<4x4xf32>) -> tensor<4x4xf32> { + // CHECK-SAME: (%arg1 = %arg0 : tensor<4x4xf32>) -> (tensor<4x4xf32>) { // CHECK-NEXT: %3 = mhlo.add %arg1, %arg1 : tensor<4x4xf32> %0 = mhlo.add %arg0, %arg0 : tensor<4x4xf32> // CHECK-NEXT: flow.return %3 : tensor<4x4xf32> // CHECK-NEXT: } // CHECK: %[[R1:.+]] = flow.dispatch.region // CHECK-SAME: [%[[WORKLOAD1]] : index] - // CHECK-SAME: (%arg1 = %[[R0]] : tensor<4x4xf32>, %arg2 = %arg0 : tensor<4x4xf32>) -> tensor<4x4xf32> { + // CHECK-SAME: (%arg1 = %[[R0]] : tensor<4x4xf32>, %arg2 = %arg0 : tensor<4x4xf32>) -> (tensor<4x4xf32>) { // CHECK-NEXT: %3 = "mhlo.dot"(%arg1, %arg2) : (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> %1 = "mhlo.dot"(%0, %arg0) : (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> // CHECK-NEXT: flow.return %3 : tensor<4x4xf32> // CHECK-NEXT: } // CHECK: %[[R2:.+]] = flow.dispatch.region // CHECK-SAME: [%[[WORKLOAD2]] : index] - // CHECK-SAME: (%arg1 = %[[R1]] : tensor<4x4xf32>, %arg2 = %arg0 : tensor<4x4xf32>) -> tensor<4x4xf32> { + // CHECK-SAME: (%arg1 = %[[R1]] : tensor<4x4xf32>, %arg2 = %arg0 : tensor<4x4xf32>) -> (tensor<4x4xf32>) { // CHECK-NEXT: %3 = mhlo.multiply %arg1, %arg2 : tensor<4x4xf32> %2 = mhlo.multiply %1, %arg0 : tensor<4x4xf32> // CHECK-NEXT: flow.return %3 : tensor<4x4xf32> @@ -103,7 +103,7 @@ func @caller(%arg0 : tensor<4xf32>) -> tensor<4xf32> { // CHECK-NEXT: %[[WORKLOAD0:.+]] = constant 4 : index // CHECK-NEXT: %[[R0:.+]] = flow.dispatch.region // CHECK-SAME: [%[[WORKLOAD0]] : index] - // CHECK-SAME: (%arg1 = %arg0 : tensor<4xf32>) -> tensor<4xf32> { + // CHECK-SAME: (%arg1 = %arg0 : tensor<4xf32>) -> (tensor<4xf32>) { // CHECK-NEXT: %1 = mhlo.add %arg1, %arg1 : tensor<4xf32> %0 = mhlo.add %arg0, %arg0 : tensor<4xf32> // CHECK-NEXT: %2 = call @callee(%1) : (tensor<4xf32>) -> tensor<4xf32> @@ -120,7 +120,7 @@ func @callee(%arg0 : tensor<4xf32>) -> tensor<4xf32> { // CHECK: %[[WORKLOAD0:.+]] = constant 4 : index // CHECK: %[[R0:.+]] = flow.dispatch.region // CHECK-SAME: [%[[WORKLOAD0]] : index] - // CHECK-SAME: (%arg1 = %arg0 : tensor<4xf32>) -> tensor<4xf32> { + // CHECK-SAME: (%arg1 = %arg0 : tensor<4xf32>) -> (tensor<4xf32>) { // CHECK-NEXT: %1 = mhlo.multiply %arg1, %arg1 : tensor<4xf32> %0 = mhlo.multiply %arg0, %arg0 : tensor<4xf32> // CHECK-NEXT: flow.return %1 : tensor<4xf32> @@ -138,7 +138,7 @@ func @single_reduction(%arg0 : tensor<4x8xf32>) -> tensor<4xf32> { // CHECK-DAG: %[[WORKLOAD0:.+]] = constant 4 : index // CHECK: %[[RESULT:.+]] = flow.dispatch.region // CHECK-SAME: [%[[WORKLOAD0]] : index] - // CHECK-SAME: (%arg1 = %arg0 : tensor<4x8xf32>, %arg2 = %[[INITIAL]] : tensor) -> tensor<4xf32> + // CHECK-SAME: (%arg1 = %arg0 : tensor<4x8xf32>, %arg2 = %[[INITIAL]] : tensor) -> (tensor<4xf32>) // CHECK-NEXT: = "mhlo.reduce"(%arg1, %arg2) %1 = "mhlo.reduce"(%arg0, %0) ( { ^bb0(%arg1 : tensor, %arg2 : tensor): diff --git a/iree/compiler/Dialect/Flow/Transforms/test/identify_dispatch_regions2_enable_matmul_fusion.mlir b/iree/compiler/Dialect/Flow/Transforms/test/identify_dispatch_regions2_enable_matmul_fusion.mlir index 5ed4cea47ed6..ae8744462327 100644 --- a/iree/compiler/Dialect/Flow/Transforms/test/identify_dispatch_regions2_enable_matmul_fusion.mlir +++ b/iree/compiler/Dialect/Flow/Transforms/test/identify_dispatch_regions2_enable_matmul_fusion.mlir @@ -1,4 +1,4 @@ -// RUN: iree-opt -split-input-file -iree-flow-dispatchability-analysis -iree-flow-identify-dispatch-regions2 -iree-enable-consumer-only-fusion %s | IreeFileCheck %s +// RUN: iree-opt -split-input-file -iree-flow-dispatchability-analysis -iree-flow-identify-dispatch-regions2 -iree-enable-consumer-only-fusion -canonicalize %s | IreeFileCheck %s func @simpleDotAddMul (%arg0 : tensor<16x32xf32>, %arg1 : tensor<32x48xf32>, diff --git a/iree/compiler/Dialect/Flow/Transforms/test/identify_dispatch_regions2_hlo.mlir b/iree/compiler/Dialect/Flow/Transforms/test/identify_dispatch_regions2_hlo.mlir index 3389a5aef4e2..eb1e8e468b4a 100644 --- a/iree/compiler/Dialect/Flow/Transforms/test/identify_dispatch_regions2_hlo.mlir +++ b/iree/compiler/Dialect/Flow/Transforms/test/identify_dispatch_regions2_hlo.mlir @@ -5,7 +5,7 @@ func @simpleMath(%arg0 : tensor<4xf32>) -> tensor<4xf32> { // CHECK-NEXT: %[[WORKLOAD:.+]] = constant 4 // CHECK-NEXT: %[[R1:.+]] = flow.dispatch.region // CHECK-SAME: [%[[WORKLOAD]] : index] - // CHECK-SAME: (%arg1 = %arg0 : tensor<4xf32>) -> tensor<4xf32> { + // CHECK-SAME: (%arg1 = %arg0 : tensor<4xf32>) -> (tensor<4xf32>) { // CHECK-NEXT: %1 = mhlo.add %arg1, %arg1 : tensor<4xf32> %0 = mhlo.add %arg0, %arg0 : tensor<4xf32> // CHECK-NEXT: flow.return %1 : tensor<4xf32> @@ -55,7 +55,7 @@ func @callee(%arg0 : tensor<4xf32>) -> tensor<4xf32> { // CHECK: %[[WORKLOAD0:.+]] = constant 4 : index // CHECK: %[[R0:.+]] = flow.dispatch.region // CHECK-SAME: [%[[WORKLOAD0]] : index] - // CHECK-SAME: (%arg1 = %arg0 : tensor<4xf32>) -> tensor<4xf32> { + // CHECK-SAME: (%arg1 = %arg0 : tensor<4xf32>) -> (tensor<4xf32>) { // CHECK-NEXT: %1 = mhlo.multiply %arg1, %arg1 : tensor<4xf32> %0 = mhlo.multiply %arg0, %arg0 : tensor<4xf32> // CHECK-NEXT: flow.return %1 : tensor<4xf32> @@ -87,8 +87,9 @@ func @single_reduction(%arg0 : tensor<4x8xf32>) -> tensor<4xf32> { // CHECK-DAG: %[[WORKLOAD0:.+]] = constant 4 : index // CHECK: %[[RESULT:.+]] = flow.dispatch.region // CHECK-SAME: [%[[WORKLOAD0]] : index] - // CHECK-SAME: (%arg1 = %arg0 : tensor<4x8xf32>, %arg2 = %[[INITIAL]] : tensor) -> tensor<4xf32> - // CHECK-NEXT: = "mhlo.reduce"(%arg1, %arg2) + // CHECK-SAME: (%arg1 = %arg0 : tensor<4x8xf32>) -> (tensor<4xf32>) + // CHECK-NEXT: %[[CST_0:.+]] = constant dense<0.0 + // CHECK-NEXT: = "mhlo.reduce"(%arg1, %[[CST_0]]) %1 = "mhlo.reduce"(%arg0, %0) ( { ^bb0(%arg1 : tensor, %arg2 : tensor): %2 = mhlo.add %arg1, %arg2 : tensor @@ -110,8 +111,10 @@ func @multi_reduction(%arg0 : tensor<4x8xf32>, %arg1 : tensor<4x8xf32>) -> (tens // CHECK-DAG: %[[WORKLOAD0:.+]] = constant 4 : index // CHECK: %[[RESULT:.+]]:2 = flow.dispatch.region // CHECK-SAME: [%[[WORKLOAD0]] : index] - // CHECK-SAME: (%arg2 = %arg0 : tensor<4x8xf32>, %arg3 = %arg1 : tensor<4x8xf32>, %arg4 = %[[INITIALA]] : tensor, %arg5 = %[[INITIALB]] : tensor) -> (tensor<4xf32>, tensor<4xf32>) - // CHECK-NEXT: = "mhlo.reduce"(%arg2, %arg3, %arg4, %arg5) + // CHECK-SAME: (%arg2 = %arg0 : tensor<4x8xf32>, %arg3 = %arg1 : tensor<4x8xf32>) -> (tensor<4xf32>, tensor<4xf32>) + // CHECK-NEXT: %[[CST_0:.+]] = constant dense<0.0 + // CHECK-NEXT: %[[CST_1:.+]] = constant dense<1.0 + // CHECK-NEXT: = "mhlo.reduce"(%arg2, %arg3, %[[CST_0]], %[[CST_1]]) %2, %3 = "mhlo.reduce"(%arg0, %arg1, %0, %1) ( { ^bb0(%arg0_lhs : tensor, %arg1_lhs : tensor, %arg0_rhs : tensor, %arg1_rhs : tensor): %4 = mhlo.add %arg0_lhs, %arg0_rhs : tensor @@ -127,7 +130,7 @@ func @multi_reduction(%arg0 : tensor<4x8xf32>, %arg1 : tensor<4x8xf32>) -> (tens // ----- -// CHECK-LABEL: @clone_broadcas +// CHECK-LABEL: @clone_broadcast func @clone_broadcast(%arg0: tensor<4x4xf32>, %arg1: tensor<4x4xf32>) -> tensor<4x4xf32> { %splatCst = constant dense<1.0> : tensor // CHECK: flow.dispatch.region @@ -140,7 +143,7 @@ func @clone_broadcast(%arg0: tensor<4x4xf32>, %arg1: tensor<4x4xf32>) -> tensor< // CHECK: mhlo.add %0 = "mhlo.broadcast"(%splatCst) {broadcast_sizes = dense<[4, 4]> : tensor<2xi64>} : (tensor) -> tensor<4x4xf32> %1 = "mhlo.add"(%0, %arg0) : (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> - %2 = "mhlo.dot"(%arg0, %arg1) : (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> + %2 = "mhlo.dot"(%1, %arg1) : (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> %3 = "mhlo.add"(%0, %2) : (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> return %3: tensor<4x4xf32> } diff --git a/iree/compiler/Dialect/Flow/Transforms/test/identify_dispatch_regions2_linalg.mlir b/iree/compiler/Dialect/Flow/Transforms/test/identify_dispatch_regions2_linalg.mlir index ef4584275090..1290461ccfa8 100644 --- a/iree/compiler/Dialect/Flow/Transforms/test/identify_dispatch_regions2_linalg.mlir +++ b/iree/compiler/Dialect/Flow/Transforms/test/identify_dispatch_regions2_linalg.mlir @@ -25,29 +25,27 @@ func @constant_capture(%arg0 : tensor<10x20xf32>) -> tensor<10x20xf32> { } // CHECK: func @constant_capture // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<10x20xf32> -// CHECK-DAG: %[[CST1:.+]] = constant 1.000000e+00 : f32 -// CHECK-DAG: %[[CST2:.+]] = constant dense<2.000000e+00> : tensor<10x20xf32> // CHECK-DAG: %[[CST3:.+]] = constant dense<[1.000000e+00, 2.000000e+00, // CHECK-SAME: 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00, // CHECK-SAME: 7.000000e+00, 8.000000e+00, 9.000000e+00, 1.000000e+01]> // CHECK: %[[RESULT:.+]] = flow.dispatch.region[%{{.+}} : index]( // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]] = %[[ARG0]] -// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]] = %[[CST2]] -// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]] = %[[CST3]] -// CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]+]] = %[[CST1]] -// CHECK-SAME: ) -> tensor<10x20xf32> { -// CHECK: %[[T0:.+]] = linalg.init_tensor [10, 20] : tensor<10x20xf32> +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]] = %[[CST3]] +// CHECK-SAME: ) -> (tensor<10x20xf32>) { +// CHECK-DAG: %[[CST1:.+]] = constant 1.000000e+00 : f32 +// CHECK-DAG: %[[CST2:.+]] = constant dense<2.000000e+00> : tensor<10x20xf32> +// CHECK-DAG: %[[T0:.+]] = linalg.init_tensor [10, 20] : tensor<10x20xf32> // CHECK: %[[RETURN:.+]] = linalg.generic -// CHECK-SAME: ins(%[[ARG1]], %[[ARG2]], %[[ARG3]] +// CHECK-SAME: ins(%[[ARG1]], %[[CST2]], %[[ARG2]] // CHECK-SAME: ) outs(%[[T0]] : tensor<10x20xf32>) { // CHECK-NEXT: ^{{[a-zA-Z0-9]+}}( +// CHECK-SAME: %[[ARG3:.[a-zA-Z0-9_]+]]: f32, +// CHECK-SAME: %[[ARG4:.[a-zA-Z0-9_]+]]: f32, // CHECK-SAME: %[[ARG5:.[a-zA-Z0-9_]+]]: f32, -// CHECK-SAME: %[[ARG6:.[a-zA-Z0-9_]+]]: f32, -// CHECK-SAME: %[[ARG7:.[a-zA-Z0-9_]+]]: f32, -// CHECK-SAME: %[[ARG8:.[a-zA-Z0-9_]+]]: f32) -// CHECK: %[[T0:.+]] = addf %[[ARG5]], %[[ARG4]] -// CHECK: %[[T1:.+]] = mulf %[[T0]], %[[ARG6]] -// CHECK: %[[T2:.+]] = addf %[[T1]], %[[ARG7]] +// CHECK-SAME: %[[ARG6:.[a-zA-Z0-9_]+]]: f32) +// CHECK: %[[T0:.+]] = addf %[[ARG3]], %[[CST1]] +// CHECK: %[[T1:.+]] = mulf %[[T0]], %[[ARG4]] +// CHECK: %[[T2:.+]] = addf %[[T1]], %[[ARG5]] // CHECK: linalg.yield %[[T2]] // CHECK: } // CHECK: flow.return %[[RETURN]] diff --git a/iree/compiler/Dialect/Flow/Transforms/test/identify_dispatch_regions2_shapes.mlir b/iree/compiler/Dialect/Flow/Transforms/test/identify_dispatch_regions2_shapes.mlir index fd7566cb537a..4212eaf00e76 100644 --- a/iree/compiler/Dialect/Flow/Transforms/test/identify_dispatch_regions2_shapes.mlir +++ b/iree/compiler/Dialect/Flow/Transforms/test/identify_dispatch_regions2_shapes.mlir @@ -15,7 +15,7 @@ func @singleDispatchWithShapes(%arg0 : tensor, // make generic. // CHECK: %[[R0:.+]] = flow.dispatch.region[%[[UNUSED_WORKLOAD:.+]] : index]( // CHECK-SAME: %[[CA2:.+]] = %[[A2]] : !shapex.ranked_shape<[?,4]>, - // CHECK-SAME: %[[CA0:.+]] = %[[TS0]] : tensor, + // CHECK-SAME: %[[CA0:.+]] = %{{.+}} : tensor, // CHECK-SAME: %[[CA1:.+]] = %[[A1]] : !shapex.ranked_shape<[?,4]>) // Dispatch region should contain captured tie_shapes. // CHECK: %[[R1:.+]] = shapex.tie_shape %[[CA0]], %[[CA1]] diff --git a/iree/compiler/Dialect/Flow/Transforms/test/identify_dispatch_regions2_std_fusion.mlir b/iree/compiler/Dialect/Flow/Transforms/test/identify_dispatch_regions2_std_fusion.mlir index 36568aba73f7..b7219bd2ad3b 100644 --- a/iree/compiler/Dialect/Flow/Transforms/test/identify_dispatch_regions2_std_fusion.mlir +++ b/iree/compiler/Dialect/Flow/Transforms/test/identify_dispatch_regions2_std_fusion.mlir @@ -13,7 +13,7 @@ func @stdElementwiseOps(%arg0 : tensor<4xf32>) -> tensor<4xf32> { // CHECK-NEXT: %[[WORKLOAD:.+]] = constant 4 // CHECK-NEXT: %[[R1:.+]] = flow.dispatch.region // CHECK-SAME: [%[[WORKLOAD]] : index] - // CHECK-SAME: (%arg1 = %arg0 : tensor<4xf32>) -> tensor<4xf32> { + // CHECK-SAME: (%arg1 = %arg0 : tensor<4xf32>) -> (tensor<4xf32>) { // CHECK-NEXT: %1 = addf %arg1, %arg1 : tensor<4xf32> %0 = addf %arg0, %arg0 : tensor<4xf32> // CHECK-NEXT: flow.return %1 : tensor<4xf32> diff --git a/iree/compiler/Dialect/Flow/Transforms/test/inject_dispatch_tracing.mlir b/iree/compiler/Dialect/Flow/Transforms/test/inject_dispatch_tracing.mlir index df5a57fce9dd..d99d312c71c7 100644 --- a/iree/compiler/Dialect/Flow/Transforms/test/inject_dispatch_tracing.mlir +++ b/iree/compiler/Dialect/Flow/Transforms/test/inject_dispatch_tracing.mlir @@ -5,8 +5,8 @@ func @singleDispatch(%arg0: tensor<4xf32>) -> tensor<4xf32> { %c4 = constant 4 : index // CHECK: flow.tensor.trace {key = "ex::entry0 inputs"} %[[ARG0]] : tensor<4xf32> - // CHECK-NEXT: %[[RET0:.+]] = flow.dispatch @ex::@entry0[%c4] (%[[ARG0]]) : (tensor<4xf32>) -> tensor<4xf32> - %0 = flow.dispatch @ex::@entry0[%c4] (%arg0) : (tensor<4xf32>) -> tensor<4xf32> + // CHECK-NEXT: %[[RET0:.+]] = flow.dispatch @ex::@entry0[%c4](%[[ARG0]]) : (tensor<4xf32>) -> (tensor<4xf32>) + %0 = flow.dispatch @ex::@entry0[%c4](%arg0) : (tensor<4xf32>) -> tensor<4xf32> // CHECK-NEXT: flow.tensor.trace {key = "ex::entry0 outputs"} %[[RET0]] : tensor<4xf32> // CHECK-NEXT: return %[[RET0]] return %0 : tensor<4xf32> @@ -20,13 +20,13 @@ func @multiDispatch(%arg0: tensor<4xf32>) -> tensor<4xf32> { %c4 = constant 4 : index // CHECK: flow.tensor.trace {key = "ex::entry0 inputs"} %[[ARG0]] : tensor<4xf32> - // CHECK-NEXT: %[[RET0:.+]] = flow.dispatch @ex::@entry0[%c4] (%[[ARG0]]) : (tensor<4xf32>) -> tensor<4xf32> - %0 = flow.dispatch @ex::@entry0[%c4] (%arg0) : (tensor<4xf32>) -> tensor<4xf32> + // CHECK-NEXT: %[[RET0:.+]] = flow.dispatch @ex::@entry0[%c4](%[[ARG0]]) : (tensor<4xf32>) -> (tensor<4xf32>) + %0 = flow.dispatch @ex::@entry0[%c4](%arg0) : (tensor<4xf32>) -> tensor<4xf32> // CHECK-NEXT: flow.tensor.trace {key = "ex::entry0 outputs"} %[[RET0]] : tensor<4xf32> // CHECK: flow.tensor.trace {key = "ex::entry1 inputs"} %[[RET0]] : tensor<4xf32> - // CHECK-NEXT: %[[RET1:.+]] = flow.dispatch @ex::@entry1[%c4] (%[[RET0]]) : (tensor<4xf32>) -> tensor<4xf32> - %1 = flow.dispatch @ex::@entry1[%c4] (%0) : (tensor<4xf32>) -> tensor<4xf32> + // CHECK-NEXT: %[[RET1:.+]] = flow.dispatch @ex::@entry1[%c4](%[[RET0]]) : (tensor<4xf32>) -> (tensor<4xf32>) + %1 = flow.dispatch @ex::@entry1[%c4](%0) : (tensor<4xf32>) -> tensor<4xf32> // CHECK-NEXT: flow.tensor.trace {key = "ex::entry1 outputs"} %[[RET1]] : tensor<4xf32> // CHECK: return %[[RET1]] diff --git a/iree/compiler/Dialect/Flow/Transforms/test/outline_dispatch_regions2.mlir b/iree/compiler/Dialect/Flow/Transforms/test/outline_dispatch_regions2.mlir index a58a4f7a96b2..2719aa3f8882 100644 --- a/iree/compiler/Dialect/Flow/Transforms/test/outline_dispatch_regions2.mlir +++ b/iree/compiler/Dialect/Flow/Transforms/test/outline_dispatch_regions2.mlir @@ -24,9 +24,9 @@ func @staticShapeDispatch(%arg0 : tensor<8x4xf32>) -> tensor<4x8xf32> { %y = constant 50 : index // CHECK: %[[RET:.+]] = flow.dispatch @staticShapeDispatch_dispatch_0::@staticShapeDispatch_dispatch_0[ // CHECK-SAME: %[[X]], %[[Y]] - // CHECK-SAME: ] (%[[ARG0]]) : (tensor<8x4xf32>) -> tensor<4x8xf32> - %0 = flow.dispatch.workgroups[%x, %y](%arg0) : (tensor<8x4xf32>) -> (tensor<4x8xf32>) = ( - %arg : !flow.dispatch.input<8x4xf32>, %ret : !flow.dispatch.output<4x8xf32> + // CHECK-SAME: ](%[[ARG0]]) : (tensor<8x4xf32>) -> (tensor<4x8xf32>) + %0 = flow.dispatch.workgroups[%x, %y](%arg0) : (tensor<8x4xf32>) -> tensor<4x8xf32> = ( + %arg: !flow.dispatch.input<8x4xf32>, %ret: !flow.dispatch.output<4x8xf32> ) { %arg_value = flow.dispatch.input.load %arg : !flow.dispatch.input<8x4xf32> -> tensor<8x4xf32> %arg_shape = flow.dispatch.shape %arg : !flow.dispatch.input<8x4xf32> -> !shapex.ranked_shape<[8,4]> @@ -62,9 +62,9 @@ func @dispatchFnMuli(%arg0 : tensor<8x4xf32>) -> tensor<8x4xf32> { %y = constant 50 : index // CHECK: %[[RET0:.+]] = flow.dispatch @dispatchFnMuli_dispatch_0::@dispatchFnMuli_dispatch_0[ // CHECK-SAME: %[[X]], %[[Y]] - // CHECK-SAME: ] (%[[ARG0]]) : (tensor<8x4xf32>) -> tensor<4x8xf32> + // CHECK-SAME: ](%[[ARG0]]) : (tensor<8x4xf32>) -> (tensor<4x8xf32>) %0 = flow.dispatch.workgroups[%x, %y](%arg0) : (tensor<8x4xf32>) -> (tensor<4x8xf32>) = ( - %arg : !flow.dispatch.input<8x4xf32>, %ret : !flow.dispatch.output<4x8xf32> + %arg: !flow.dispatch.input<8x4xf32>, %ret: !flow.dispatch.output<4x8xf32> ) { %arg_value = flow.dispatch.input.load %arg : !flow.dispatch.input<8x4xf32> -> tensor<8x4xf32> %arg_shape = flow.dispatch.shape %arg : !flow.dispatch.input<8x4xf32> -> !shapex.ranked_shape<[8,4]> @@ -75,9 +75,9 @@ func @dispatchFnMuli(%arg0 : tensor<8x4xf32>) -> tensor<8x4xf32> { } // CHECK: %[[RET1:.+]] = flow.dispatch @dispatchFnMuli_dispatch_1::@dispatchFnMuli_dispatch_1[ // CHECK-SAME: %[[Y]], %[[X]] - // CHECK-SAME: ] (%[[RET0]]) : (tensor<4x8xf32>) -> tensor<8x4xf32> + // CHECK-SAME: ](%[[RET0]]) : (tensor<4x8xf32>) -> (tensor<8x4xf32>) %1 = flow.dispatch.workgroups[%y, %x](%0) : (tensor<4x8xf32>) -> (tensor<8x4xf32>) = ( - %arg : !flow.dispatch.input<4x8xf32>, %ret : !flow.dispatch.output<8x4xf32> + %arg: !flow.dispatch.input<4x8xf32>, %ret: !flow.dispatch.output<8x4xf32> ) { %arg_value = flow.dispatch.input.load %arg : !flow.dispatch.input<4x8xf32> -> tensor<8x4xf32> %arg_shape = flow.dispatch.shape %arg : !flow.dispatch.input<4x8xf32> -> !shapex.ranked_shape<[4,8]> @@ -100,7 +100,7 @@ func @dispatchFn1(%arg0 : tensor<8x4xf32>) -> tensor<4x8xf32> { %y = constant 50 : index // CHECK: flow.dispatch @dispatchFn1_dispatch_0::@dispatchFn1_dispatch_0 %0 = flow.dispatch.workgroups[%x, %y](%arg0) : (tensor<8x4xf32>) -> (tensor<4x8xf32>) = ( - %arg : !flow.dispatch.input<8x4xf32>, %ret : !flow.dispatch.output<4x8xf32> + %arg: !flow.dispatch.input<8x4xf32>, %ret: !flow.dispatch.output<4x8xf32> ) { flow.return } @@ -115,7 +115,7 @@ func @dispatchFn2(%arg0 : tensor<8x4xf32>) -> tensor<4x8xf32> { %y = constant 50 : index // CHECK: flow.dispatch @dispatchFn2_dispatch_0::@dispatchFn2_dispatch_0 %0 = flow.dispatch.workgroups[%x, %y](%arg0) : (tensor<8x4xf32>) -> (tensor<4x8xf32>) = ( - %arg : !flow.dispatch.input<8x4xf32>, %ret : !flow.dispatch.output<4x8xf32> + %arg: !flow.dispatch.input<8x4xf32>, %ret: !flow.dispatch.output<4x8xf32> ) { flow.return } @@ -169,21 +169,17 @@ func @dynamicShapeDispatch(%arg0 : tensor<7x?x24x?xf32>) -> tensor // CHECK-DAG: %[[Y:.+]] = constant 512 %y = constant 512 : index // CHECK-NEXT: %[[ARG0_SHAPE:.+]] = shapex.make_ranked_shape %[[ARG0_DIM1]], %[[ARG0_DIM3]] - %arg0_shape = shapex.make_ranked_shape %dim1, %dim3 : (index, index) -> !shapex.ranked_shape<[7,?,24,?]> - // CHECK-NEXT: %[[ARG0_SHAPED:.+]] = shapex.tie_shape %[[ARG0]], %[[ARG0_SHAPE]] - %arg0_shaped = shapex.tie_shape %arg0, %arg0_shape : tensor<7x?x24x?xf32>, !shapex.ranked_shape<[7,?,24,?]> - // CHECK-NEXT: %[[RET0_SHAPE:.+]] = shapex.make_ranked_shape %[[ARG0_DIM3]], %[[ARG0_DIM1]] - %ret0_shape = shapex.make_ranked_shape %dim3, %dim1 : (index, index) -> !shapex.ranked_shape<[?,?,1024]> // CHECK-DAG: %[[IN_ARG0_DIM1:.+]] = shapex.ranked_dim %[[ARG0_SHAPE]][1] // CHECK-DAG: %[[IN_ARG0_DIM3:.+]] = shapex.ranked_dim %[[ARG0_SHAPE]][3] + // CHECK-NEXT: %[[RET0_SHAPE:.+]] = shapex.make_ranked_shape %[[ARG0_DIM3]], %[[ARG0_DIM1]] // CHECK-DAG: %[[IN_RET0_DIM0:.+]] = shapex.ranked_dim %[[RET0_SHAPE]][0] // CHECK-DAG: %[[IN_RET0_DIM1:.+]] = shapex.ranked_dim %[[RET0_SHAPE]][1] // CHECK-NEXT: %[[RET0:.+]] = flow.dispatch @dynamicShapeDispatch_dispatch_0::@dynamicShapeDispatch_dispatch_0[ // CHECK-SAME: %[[X]], %[[Y]] - // CHECK-SAME: ] (%[[ARG0_SHAPED]], %[[IN_ARG0_DIM1]], %[[IN_ARG0_DIM3]], %[[IN_RET0_DIM0]], %[[IN_RET0_DIM1]]) - // CHECK-SAME: : (tensor<7x?x24x?xf32>, index, index, index, index) -> tensor - %ret0 = flow.dispatch.workgroups[%x, %y](%arg0_shaped) : (tensor<7x?x24x?xf32>) -> tensor = ( - %arg : !flow.dispatch.input<7x?x24x?xf32>, %ret : !flow.dispatch.output + // CHECK-SAME: ](%arg0, %[[IN_ARG0_DIM1]], %[[IN_ARG0_DIM3]], %[[IN_RET0_DIM0]], %[[IN_RET0_DIM1]]) + // CHECK-SAME: : (tensor<7x?x24x?xf32>{%[[IN_ARG0_DIM1]], %[[IN_ARG0_DIM3]]}, index, index, index, index) -> (tensor{%[[IN_RET0_DIM0]], %[[IN_RET0_DIM1]]}) + %ret0 = flow.dispatch.workgroups[%x, %y](%arg0) : (tensor<7x?x24x?xf32>{%dim1, %dim3}) -> tensor{%dim3, %dim1} = ( + %arg: !flow.dispatch.input<7x?x24x?xf32>, %ret: !flow.dispatch.output ) { %workgroup_rank = flow.dispatch.workgroup.rank : index @@ -207,8 +203,6 @@ func @dynamicShapeDispatch(%arg0 : tensor<7x?x24x?xf32>) -> tensor flow.return } - // CHECK-NEXT: %[[RET0_SHAPED:.+]] = shapex.tie_shape %[[RET0]], %[[RET0_SHAPE]] - %ret0_shaped = shapex.tie_shape %ret0, %ret0_shape : tensor, !shapex.ranked_shape<[?,?,1024]> - // CHECK-NEXT: return %[[RET0_SHAPED]] - return %ret0_shaped : tensor + // CHECK-NEXT: return %[[RET0]] + return %ret0 : tensor } diff --git a/iree/compiler/Dialect/Flow/Transforms/test/outline_dispatch_regions_ranked_dynamic.mlir b/iree/compiler/Dialect/Flow/Transforms/test/outline_dispatch_regions_ranked_dynamic.mlir deleted file mode 100644 index 60c843acac26..000000000000 --- a/iree/compiler/Dialect/Flow/Transforms/test/outline_dispatch_regions_ranked_dynamic.mlir +++ /dev/null @@ -1,44 +0,0 @@ -// RUN: iree-opt -allow-unregistered-dialect -split-input-file -iree-flow-outline-dispatch-regions -canonicalize %s | IreeFileCheck %s -// NOTE: Most of the common cases for outlining are tested via -// transformation.mlir; however, this test performs some specific tests -// of corner cases that are easier to access at this level. - -// CHECK-LABEL: @dynamicRankedShapeModule -// Verify that the outlined function properly expands shape dims -// Note that all but the entry shape ties/ops are removed. -// CHECK: flow.executable @dynamicRankedShape_ex_dispatch_0 -// CHECK: func @dynamicRankedShape_ex_dispatch_0(%[[EXARG0:.+]]: tensor<7x?x24x?xf32>, %[[EXARG1:.+]]: index, %[[EXARG2:.+]]: index) -> tensor { -// CHECK-DAG: %[[EXSHAPE0:.+]] = shapex.make_ranked_shape %[[EXARG1]], %[[EXARG2]] -// CHECK-DAG: %[[EXT0:.+]] = shapex.tie_shape %[[EXARG0]], %[[EXSHAPE0]] -// CHECK-DAG: %[[EXT1:.+]] = "some_kind_of_sum"(%[[EXT0]]) -// CHECK-DAG: return %[[EXT1]] -// Verify that the generated flow.dispatch op properly inputs individual shape dims -// CHECK: func @dynamicRankedShape(%[[ARG0:.+]]: tensor<7x?x24x?xf32>) -// CHECK-DAG: %[[C1:.*]] = constant 1 : index -// CHECK-DAG: %[[C3:.*]] = constant 3 : index -// CHECK-DAG: %[[D1:.+]] = dim %[[ARG0]], %[[C1]] -// CHECK-DAG: %[[D3:.+]] = dim %[[ARG0]], %[[C3]] -// CHECK-DAG: %[[WORKLOAD0:.+]] = constant 1024 : index -// CHECK-DAG: %[[DISPATCH:.+]] = flow.dispatch @dynamicRankedShape_ex_dispatch_0::@dynamicRankedShape_ex_dispatch_0[%[[WORKLOAD0]]] (%[[ARG0]], %[[D1]], %[[D3]]) : (tensor<7x?x24x?xf32>, index, index) -// CHECK-DAG: return %[[DISPATCH]] -module @dynamicRankedShapeModule { -func @dynamicRankedShape(%arg0 : tensor<7x?x24x?xf32>) -> tensor { - %c1 = constant 1 : index - %c3 = constant 3 : index - %dim1 = dim %arg0, %c1 : tensor<7x?x24x?xf32> - %dim3 = dim %arg0, %c3 : tensor<7x?x24x?xf32> - %workload0 = constant 1024 : index - %shape0 = shapex.make_ranked_shape %dim1, %dim3 : (index, index) -> !shapex.ranked_shape<[7,?,24,?]> - %1 = flow.dispatch.region[%workload0 : index](%arg1 = %arg0 : tensor<7x?x24x?xf32>, %arg2 = %shape0 : !shapex.ranked_shape<[7,?,24,?]>) -> tensor { - %2 = shapex.tie_shape %arg1, %arg2 : tensor<7x?x24x?xf32>, !shapex.ranked_shape<[7,?,24,?]> - // Simulate a custom op that shuffles the input in a weird way. - %3 = "some_kind_of_sum"(%2) : (tensor<7x?x24x?xf32>) -> tensor - %4 = shapex.ranked_dim %arg2[1] : !shapex.ranked_shape<[7,?,24,?]> -> index - %5 = shapex.ranked_dim %arg2[3] : !shapex.ranked_shape<[7,?,24,?]> -> index - %6 = shapex.make_ranked_shape %4, %5 : (index, index) -> !shapex.ranked_shape<[?,?,1024]> - %7 = shapex.tie_shape %3, %6 : tensor, !shapex.ranked_shape<[?,?,1024]> - flow.return %7 : tensor - } - return %1 : tensor -} -} diff --git a/iree/compiler/Dialect/Flow/Transforms/test/rematerialize_dispatch_constants.mlir b/iree/compiler/Dialect/Flow/Transforms/test/rematerialize_dispatch_constants.mlir index bfc6b7c898ac..d2f434a1429b 100644 --- a/iree/compiler/Dialect/Flow/Transforms/test/rematerialize_dispatch_constants.mlir +++ b/iree/compiler/Dialect/Flow/Transforms/test/rematerialize_dispatch_constants.mlir @@ -5,7 +5,7 @@ func @rematerializeSmall(%arg0 : tensor<4x4xf32>) -> tensor<4x4xf32> { // CHECK: %[[WORKLOAD0:.+]] = constant 16 : index %cst = constant 16 : index %small = constant dense<1.23> : tensor<4x4xf32> - // CHECK: %[[R0:.+]] = flow.dispatch.region[%[[WORKLOAD0]] : index](%arg1 = %arg0 : tensor<4x4xf32>) -> tensor<4x4xf32> { + // CHECK: %[[R0:.+]] = flow.dispatch.region[%[[WORKLOAD0]] : index](%arg1 = %arg0 : tensor<4x4xf32>) -> (tensor<4x4xf32>) { %0 = flow.dispatch.region[%cst : index](%arg1 = %arg0 : tensor<4x4xf32>, %arg2 = %small : tensor<4x4xf32>) -> tensor<4x4xf32> { // CHECK-NEXT: %[[REMAT_SMALL:.+]] = constant dense<1.230000e+00> : tensor<4x4xf32> // CHECK-NEXT: %1 = mhlo.add %arg1, %[[REMAT_SMALL]] : tensor<4x4xf32> @@ -22,7 +22,7 @@ func @rematerializeSplat(%arg0 : tensor<1025xi8>) -> tensor<1025xi8> { // CHECK-DAG: %[[WORKLOAD0:.+]] = constant 16 : index %cst = constant 16 : index %large = constant dense<8> : tensor<1025xi8> - // CHECK-NEXT: %[[R0:.+]] = flow.dispatch.region[%[[WORKLOAD0]] : index](%arg1 = %arg0 : tensor<1025xi8>) -> tensor<1025xi8> { + // CHECK-NEXT: %[[R0:.+]] = flow.dispatch.region[%[[WORKLOAD0]] : index](%arg1 = %arg0 : tensor<1025xi8>) -> (tensor<1025xi8>) { %0 = flow.dispatch.region[%cst : index](%arg1 = %arg0 : tensor<1025xi8>, %arg2 = %large : tensor<1025xi8>) -> tensor<1025xi8> { // CHECK-NEXT: %[[REMAT_SPLAT:.+]] = constant dense<8> : tensor<1025xi8> // CHECK-NEXT: %1 = mhlo.add %arg1, %[[REMAT_SPLAT]] : tensor<1025xi8> @@ -40,7 +40,7 @@ func @noRematerializeLarge(%arg0 : tensor<1025xi8>) -> tensor<1025xi8> { // CHECK-DAG: %[[CST:.+]] = constant dense<{{.+}}> : tensor<1025xi8> %cst = constant 16 : index %large = constant dense<[0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63,64,65,66,67,68,69,70,71,72,73,74,75,76,77,78,79,80,81,82,83,84,85,86,87,88,89,90,91,92,93,94,95,96,97,98,99,100,101,102,103,104,105,106,107,108,109,110,111,112,113,114,115,116,117,118,119,120,121,122,123,124,125,126,127,128,129,130,131,132,133,134,135,136,137,138,139,140,141,142,143,144,145,146,147,148,149,150,151,152,153,154,155,156,157,158,159,160,161,162,163,164,165,166,167,168,169,170,171,172,173,174,175,176,177,178,179,180,181,182,183,184,185,186,187,188,189,190,191,192,193,194,195,196,197,198,199,200,201,202,203,204,205,206,207,208,209,210,211,212,213,214,215,216,217,218,219,220,221,222,223,224,225,226,227,228,229,230,231,232,233,234,235,236,237,238,239,240,241,242,243,244,245,246,247,248,249,250,251,252,253,254,255,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63,64,65,66,67,68,69,70,71,72,73,74,75,76,77,78,79,80,81,82,83,84,85,86,87,88,89,90,91,92,93,94,95,96,97,98,99,100,101,102,103,104,105,106,107,108,109,110,111,112,113,114,115,116,117,118,119,120,121,122,123,124,125,126,127,128,129,130,131,132,133,134,135,136,137,138,139,140,141,142,143,144,145,146,147,148,149,150,151,152,153,154,155,156,157,158,159,160,161,162,163,164,165,166,167,168,169,170,171,172,173,174,175,176,177,178,179,180,181,182,183,184,185,186,187,188,189,190,191,192,193,194,195,196,197,198,199,200,201,202,203,204,205,206,207,208,209,210,211,212,213,214,215,216,217,218,219,220,221,222,223,224,225,226,227,228,229,230,231,232,233,234,235,236,237,238,239,240,241,242,243,244,245,246,247,248,249,250,251,252,253,254,255,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63,64,65,66,67,68,69,70,71,72,73,74,75,76,77,78,79,80,81,82,83,84,85,86,87,88,89,90,91,92,93,94,95,96,97,98,99,100,101,102,103,104,105,106,107,108,109,110,111,112,113,114,115,116,117,118,119,120,121,122,123,124,125,126,127,128,129,130,131,132,133,134,135,136,137,138,139,140,141,142,143,144,145,146,147,148,149,150,151,152,153,154,155,156,157,158,159,160,161,162,163,164,165,166,167,168,169,170,171,172,173,174,175,176,177,178,179,180,181,182,183,184,185,186,187,188,189,190,191,192,193,194,195,196,197,198,199,200,201,202,203,204,205,206,207,208,209,210,211,212,213,214,215,216,217,218,219,220,221,222,223,224,225,226,227,228,229,230,231,232,233,234,235,236,237,238,239,240,241,242,243,244,245,246,247,248,249,250,251,252,253,254,255,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63,64,65,66,67,68,69,70,71,72,73,74,75,76,77,78,79,80,81,82,83,84,85,86,87,88,89,90,91,92,93,94,95,96,97,98,99,100,101,102,103,104,105,106,107,108,109,110,111,112,113,114,115,116,117,118,119,120,121,122,123,124,125,126,127,128,129,130,131,132,133,134,135,136,137,138,139,140,141,142,143,144,145,146,147,148,149,150,151,152,153,154,155,156,157,158,159,160,161,162,163,164,165,166,167,168,169,170,171,172,173,174,175,176,177,178,179,180,181,182,183,184,185,186,187,188,189,190,191,192,193,194,195,196,197,198,199,200,201,202,203,204,205,206,207,208,209,210,211,212,213,214,215,216,217,218,219,220,221,222,223,224,225,226,227,228,229,230,231,232,233,234,235,236,237,238,239,240,241,242,243,244,245,246,247,248,249,250,251,252,253,254,255,0]> : tensor<1025xi8> - // CHECK-NEXT: %[[R0:.+]] = flow.dispatch.region[%[[WORKLOAD0]] : index](%arg1 = %arg0 : tensor<1025xi8>, %arg2 = %[[CST]] : tensor<1025xi8>) -> tensor<1025xi8> { + // CHECK-NEXT: %[[R0:.+]] = flow.dispatch.region[%[[WORKLOAD0]] : index](%arg1 = %arg0 : tensor<1025xi8>, %arg2 = %[[CST]] : tensor<1025xi8>) -> (tensor<1025xi8>) { %0 = flow.dispatch.region[%cst : index](%arg1 = %arg0 : tensor<1025xi8>, %arg2 = %large : tensor<1025xi8>) -> tensor<1025xi8> { // CHECK-NEXT: %1 = mhlo.add %arg1, %arg2 : tensor<1025xi8> %3 = mhlo.add %arg1, %arg2 : tensor<1025xi8> @@ -57,7 +57,7 @@ func @noRematerializeIntoDot(%arg0 : tensor<4x4xf32>) -> tensor<4x4xf32> { // CHECK-DAG: %[[SMALL:.+]] = constant dense<1.230000e+00> : tensor<4x4xf32> %cst = constant 16 : index %small = constant dense<1.23> : tensor<4x4xf32> - // CHECK-NEXT: %[[R0:.+]] = flow.dispatch.region[%[[WORKLOAD0]] : index](%arg1 = %arg0 : tensor<4x4xf32>, %arg2 = %[[SMALL]] : tensor<4x4xf32>) -> tensor<4x4xf32> { + // CHECK-NEXT: %[[R0:.+]] = flow.dispatch.region[%[[WORKLOAD0]] : index](%arg1 = %arg0 : tensor<4x4xf32>, %arg2 = %[[SMALL]] : tensor<4x4xf32>) -> (tensor<4x4xf32>) { %0 = flow.dispatch.region[%cst : index](%arg1 = %arg0 : tensor<4x4xf32>, %arg2 = %small : tensor<4x4xf32>) -> tensor<4x4xf32> { // CHECK-NEXT: %1 = "mhlo.dot"(%arg1, %arg2) : (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> %3 = "mhlo.dot"(%arg1, %arg2) : (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> @@ -130,7 +130,8 @@ func @constant_capture(%arg0: tensor<10x20xf32>) -> tensor<10x20xf32> { func @rematerialize_dispatch_workgroups(%arg0: tensor<8x8xf32>, %arg1: tensor<8x8xf32>) -> tensor<8x8xf32> { %cst_0 = constant 0.0 : f32 %c2 = constant 1 : index - %0 = flow.dispatch.workgroups[%c2, %c2, %c2] (%cst_0, %arg0, %arg1) : (f32, tensor<8x8xf32>, tensor<8x8xf32>) -> tensor<8x8xf32> = (%arg2 : f32, %arg3 : !flow.dispatch.input<8x8xf32>, %arg4 : !flow.dispatch.input<8x8xf32>, %arg5 : !flow.dispatch.output<8x8xf32>) { + %0 = flow.dispatch.workgroups[%c2, %c2, %c2](%cst_0, %arg0, %arg1) : (f32, tensor<8x8xf32>, tensor<8x8xf32>) -> tensor<8x8xf32> = + (%arg2: f32, %arg3: !flow.dispatch.input<8x8xf32>, %arg4: !flow.dispatch.input<8x8xf32>, %arg5: !flow.dispatch.output<8x8xf32>) { %c0 = constant 0 : index %c1 = constant 1 : index %c8 = constant 8 : index @@ -147,7 +148,7 @@ func @rematerialize_dispatch_workgroups(%arg0: tensor<8x8xf32>, %arg1: tensor<8x // CHECK: func @rematerialize_dispatch_workgroups(%[[ARG1:.+]]: tensor<8x8xf32>, %[[ARG2:.+]]: tensor<8x8xf32>) // CHECK: %[[CONST1:.+]] = constant 1 : index -// CHECK: flow.dispatch.workgroups[%[[CONST1]], %[[CONST1]], %[[CONST1]]] (%[[ARG1]], %[[ARG2]]) +// CHECK: flow.dispatch.workgroups[%[[CONST1]], %[[CONST1]], %[[CONST1]]](%[[ARG1]], %[[ARG2]]) // CHECK: %[[CONST0:.+]] = constant 0.000000e+00 : f32 // CHECK: %[[INIT_TENSOR:.+]] = linalg.init_tensor [8, 8] : tensor<8x8xf32> // CHECK: linalg.fill(%[[INIT_TENSOR]], %[[CONST0]]) : tensor<8x8xf32>, f32 -> tensor<8x8xf32> diff --git a/iree/compiler/Dialect/Flow/Transforms/test/transformation.mlir b/iree/compiler/Dialect/Flow/Transforms/test/transformation.mlir index b13d93f4103c..5c17a0a28159 100644 --- a/iree/compiler/Dialect/Flow/Transforms/test/transformation.mlir +++ b/iree/compiler/Dialect/Flow/Transforms/test/transformation.mlir @@ -8,31 +8,6 @@ func @empty() { // ----- -func @simpleMath(%arg0 : tensor<4xf32>) -> tensor<4xf32> { - %0 = mhlo.add %arg0, %arg0 : tensor<4xf32> - return %0 : tensor<4xf32> -} - -// CHECK-LABEL: flow.executable @simpleMath_ex_dispatch_0 attributes {sym_visibility = "private"} { -// CHECK-NEXT: flow.dispatch.entry @simpleMath_ex_dispatch_0 -// CHECK-NEXT: module { -// CHECK-NEXT: func @simpleMath_ex_dispatch_0(%arg0: tensor<4xf32>) -> tensor<4xf32> { -// CHECK-NEXT: %0 = mhlo.add %arg0, %arg0 : tensor<4xf32> -// CHECK-NEXT: return %0 : tensor<4xf32> -// CHECK-NEXT: } -// CHECK-NEXT: } -// CHECK-NEXT: } -// CHECK-NEXT: func @simpleMath(%arg0: tensor<4xf32>) -> tensor<4xf32> { -// CHECK-NEXT: %[[WORKLOAD0:.+]] = constant 4 : index -// CHECK-NEXT: %0 = flow.ex.stream.fragment(%arg1 = %[[WORKLOAD0]] : index, %arg2 = %arg0 : tensor<4xf32>) -> tensor<4xf32> { -// CHECK-NEXT: %1 = flow.dispatch @simpleMath_ex_dispatch_0::@simpleMath_ex_dispatch_0[%arg1] (%arg2) : (tensor<4xf32>) -> tensor<4xf32> -// CHECK-NEXT: flow.return %1 : tensor<4xf32> -// CHECK-NEXT: } -// CHECK-NEXT: return %0 : tensor<4xf32> -// CHECK-NEXT: } - -// ----- - func @stdElementwiseOps(%arg0 : tensor<4xf32>) -> tensor<4xf32> { %0 = addf %arg0, %arg0 : tensor<4xf32> %1 = subf %0, %arg0 : tensor<4xf32> @@ -52,9 +27,10 @@ func @stdElementwiseOps(%arg0 : tensor<4xf32>) -> tensor<4xf32> { // CHECK-NEXT: } // CHECK-NEXT: } // CHECK-NEXT: func @stdElementwiseOps(%arg0: tensor<4xf32>) -> tensor<4xf32> { -// CHECK-NEXT: %[[WORKLOAD0:.+]] = constant 4 : index -// CHECK-NEXT: %0 = flow.ex.stream.fragment(%arg1 = %[[WORKLOAD0]] : index, %arg2 = %arg0 : tensor<4xf32>) -> tensor<4xf32> { -// CHECK-NEXT: %1 = flow.dispatch @stdElementwiseOps_ex_dispatch_0::@stdElementwiseOps_ex_dispatch_0[%arg1] (%arg2) : (tensor<4xf32>) -> tensor<4xf32> +// CHECK-NEXT: %0 = flow.ex.stream.fragment(%arg0) : (tensor<4xf32>) -> (tensor<4xf32>) = +// CHECK-NEXT: (%arg1: tensor<4xf32>) -> (tensor<4xf32>) { +// CHECK-NEXT: %[[WORKLOAD:.+]] = constant 4 : index +// CHECK-NEXT: %1 = flow.dispatch @stdElementwiseOps_ex_dispatch_0::@stdElementwiseOps_ex_dispatch_0[%[[WORKLOAD]]](%arg1) : (tensor<4xf32>) -> (tensor<4xf32> // CHECK-NEXT: flow.return %1 : tensor<4xf32> // CHECK-NEXT: } // CHECK-NEXT: return %0 : tensor<4xf32> @@ -81,9 +57,10 @@ func @hloElementwiseOps(%arg0 : tensor<4xf32>) -> tensor<4xf32> { // CHECK-NEXT: } // CHECK-NEXT: } // CHECK-NEXT: func @hloElementwiseOps(%arg0: tensor<4xf32>) -> tensor<4xf32> { -// CHECK-NEXT: %[[WORKLOAD0:.+]] = constant 4 : index -// CHECK-NEXT: %0 = flow.ex.stream.fragment(%arg1 = %[[WORKLOAD0]] : index, %arg2 = %arg0 : tensor<4xf32>) -> tensor<4xf32> { -// CHECK-NEXT: %1 = flow.dispatch @hloElementwiseOps_ex_dispatch_0::@hloElementwiseOps_ex_dispatch_0[%arg1] (%arg2) : (tensor<4xf32>) -> tensor<4xf32> +// CHECK-NEXT: %0 = flow.ex.stream.fragment(%arg0) : (tensor<4xf32>) -> (tensor<4xf32>) = +// CHECK-NEXT: (%arg1: tensor<4xf32>) -> (tensor<4xf32>) { +// CHECK-NEXT: %[[WORKLOAD:.+]] = constant 4 : index +// CHECK-NEXT: %1 = flow.dispatch @hloElementwiseOps_ex_dispatch_0::@hloElementwiseOps_ex_dispatch_0[%[[WORKLOAD]]](%arg1) : (tensor<4xf32>) -> (tensor<4xf32>) // CHECK-NEXT: flow.return %1 : tensor<4xf32> // CHECK-NEXT: } // CHECK-NEXT: return %0 : tensor<4xf32> @@ -126,11 +103,12 @@ func @interleavedDot(%arg0 : tensor<4x4xf32>) -> tensor<4x4xf32> { // CHECK-NEXT: } // CHECK-NEXT: } // CHECK-NEXT: func @interleavedDot(%arg0: tensor<4x4xf32>) -> tensor<4x4xf32> { -// CHECK-NEXT: %[[WORKLOAD0:.+]] = constant 16 : index -// CHECK-NEXT: %0 = flow.ex.stream.fragment(%arg1 = %[[WORKLOAD0]] : index, %arg2 = %arg0 : tensor<4x4xf32>) -> tensor<4x4xf32> { -// CHECK-NEXT: %1 = flow.dispatch @interleavedDot_ex_dispatch_0::@interleavedDot_ex_dispatch_0[%arg1] (%arg2) : (tensor<4x4xf32>) -> tensor<4x4xf32> -// CHECK-NEXT: %2 = flow.dispatch @interleavedDot_ex_dispatch_1::@interleavedDot_ex_dispatch_1[%arg1] (%1, %arg2) : (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> -// CHECK-NEXT: %3 = flow.dispatch @interleavedDot_ex_dispatch_2::@interleavedDot_ex_dispatch_2[%arg1] (%2, %arg2) : (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> +// CHECK-NEXT: %0 = flow.ex.stream.fragment(%arg0) : (tensor<4x4xf32>) -> (tensor<4x4xf32>) = +// CHECK-NEXT: (%arg1: tensor<4x4xf32>) -> (tensor<4x4xf32>) { +// CHECK-NEXT: %[[WORKLOAD:.+]] = constant 16 : index +// CHECK-NEXT: %1 = flow.dispatch @interleavedDot_ex_dispatch_0::@interleavedDot_ex_dispatch_0[%[[WORKLOAD]]](%arg1) : (tensor<4x4xf32>) -> (tensor<4x4xf32>) +// CHECK-NEXT: %2 = flow.dispatch @interleavedDot_ex_dispatch_1::@interleavedDot_ex_dispatch_1[%[[WORKLOAD]]](%1, %arg1) : (tensor<4x4xf32>, tensor<4x4xf32>) -> (tensor<4x4xf32>) +// CHECK-NEXT: %3 = flow.dispatch @interleavedDot_ex_dispatch_2::@interleavedDot_ex_dispatch_2[%[[WORKLOAD]]](%2, %arg1) : (tensor<4x4xf32>, tensor<4x4xf32>) -> (tensor<4x4xf32>) // CHECK-NEXT: flow.return %3 : tensor<4x4xf32> // CHECK-NEXT: } // CHECK-NEXT: return %0 : tensor<4x4xf32> @@ -163,9 +141,10 @@ func @reduction(%arg0 : tensor<4x8xf32>) -> tensor<4xf32> { // CHECK-NEXT: } // CHECK-NEXT: } // CHECK-NEXT: func @reduction(%arg0: tensor<4x8xf32>) -> tensor<4xf32> { -// CHECK-NEXT: %[[WORKLOAD0:.+]] = constant 4 : index -// CHECK-NEXT: %0 = flow.ex.stream.fragment(%arg1 = %[[WORKLOAD0]] : index, %arg2 = %arg0 : tensor<4x8xf32>) -> tensor<4xf32> { -// CHECK-NEXT: %1 = flow.dispatch @reduction_ex_dispatch_0::@reduction_ex_dispatch_0[%arg1] (%arg2) : (tensor<4x8xf32>) -> tensor<4xf32> +// CHECK-NEXT: %0 = flow.ex.stream.fragment(%arg0) : (tensor<4x8xf32>) -> (tensor<4xf32>) = +// CHECK-NEXT: (%arg1: tensor<4x8xf32>) -> (tensor<4xf32>) { +// CHECK-NEXT: %[[WORKLOAD:.+]] = constant 4 : index +// CHECK-NEXT: %1 = flow.dispatch @reduction_ex_dispatch_0::@reduction_ex_dispatch_0[%[[WORKLOAD]]](%arg1) : (tensor<4x8xf32>) -> (tensor<4xf32>) // CHECK-NEXT: flow.return %1 : tensor<4xf32> // CHECK-NEXT: } // CHECK-NEXT: return %0 : tensor<4xf32> @@ -189,15 +168,16 @@ func @dynamicUpdateSlice(%operand : tensor<2x4xi32>, %update : tensor<1x1xi32>, // CHECK-NEXT: } // CHECK-NEXT: } // CHECK-NEXT: func @dynamicUpdateSlice(%arg0: tensor<2x4xi32>, %arg1: tensor<1x1xi32>, %arg2: tensor, %arg3: tensor) -> tensor<2x4xi32> { -// CHECK-DAG: %[[WORKLOAD0:.+]] = constant 8 : index -// CHECK-DAG: %[[ARG2_LOAD:.+]] = flow.tensor.load %arg2 : tensor -// CHECK-DAG: %[[ARG2_INDEX:.+]] = index_cast %[[ARG2_LOAD]] : i32 to index -// CHECK-DAG: %[[ARG3_LOAD:.+]] = flow.tensor.load %arg3 : tensor -// CHECK-DAG: %[[ARG3_INDEX:.+]] = index_cast %[[ARG3_LOAD]] : i32 to index -// CHECK-NEXT: %4 = flow.ex.stream.fragment(%arg4 = %arg1 : tensor<1x1xi32>, %arg5 = %arg0 : tensor<2x4xi32>, %arg6 = %[[ARG2_INDEX]] : index, %arg7 = %[[ARG3_INDEX]] : index, %arg8 = %[[WORKLOAD0]] : index) -> tensor<2x4xi32> { -// CHECK-NEXT: %5 = flow.tensor.update %arg4, %arg5[%arg6, %arg7] : tensor<1x1xi32> -> tensor<2x4xi32> -// CHECK-NEXT: %6 = flow.dispatch @dynamicUpdateSlice_ex_dispatch_0::@dynamicUpdateSlice_ex_dispatch_0[%arg8] (%arg5, %5) : (tensor<2x4xi32>, tensor<2x4xi32>) -> tensor<2x4xi32> -// CHECK-NEXT: flow.return %6 : tensor<2x4xi32> +// CHECK-DAG: %[[ARG2_LOAD:.+]] = flow.tensor.load %arg2 : tensor +// CHECK-DAG: %[[ARG2_INDEX:.+]] = index_cast %[[ARG2_LOAD]] : i32 to index +// CHECK-DAG: %[[ARG3_LOAD:.+]] = flow.tensor.load %arg3 : tensor +// CHECK-DAG: %[[ARG3_INDEX:.+]] = index_cast %[[ARG3_LOAD]] : i32 to index +// CHECK-NEXT: %[[RET:.+]] = flow.ex.stream.fragment(%arg0, %[[ARG2_INDEX]], %[[ARG3_INDEX]], %arg1) : (tensor<2x4xi32>, index, index, tensor<1x1xi32>) -> (tensor<2x4xi32>) = +// CHECK-NEXT: (%arg4: tensor<2x4xi32>, %arg5: index, %arg6: index, %arg7: tensor<1x1xi32>) -> (tensor<2x4xi32>) { +// CHECK-NEXT: %[[WORKLOAD:.+]] = constant 8 : index +// CHECK-NEXT: %[[T0:.+]] = flow.tensor.update %arg7, %arg4[%arg5, %arg6] : tensor<1x1xi32> -> tensor<2x4xi32> +// CHECK-NEXT: %[[T1:.+]] = flow.dispatch @dynamicUpdateSlice_ex_dispatch_0::@dynamicUpdateSlice_ex_dispatch_0[%[[WORKLOAD]]](%arg4, %[[T0]]) : (tensor<2x4xi32>, tensor<2x4xi32>) -> (tensor<2x4xi32>) +// CHECK-NEXT: flow.return %[[T1]] : tensor<2x4xi32> // CHECK-NEXT: } -// CHECK-NEXT: return %4 : tensor<2x4xi32> +// CHECK-NEXT: return %[[RET]] : tensor<2x4xi32> // CHECK-NEXT: } diff --git a/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/ConvertStreamOps.cpp b/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/ConvertStreamOps.cpp index af4c7c38901f..643c52160d5a 100644 --- a/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/ConvertStreamOps.cpp +++ b/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/ConvertStreamOps.cpp @@ -172,8 +172,10 @@ static void allocateTransientBuffers(IREE::Flow::ExStreamFragmentOp streamOp, // Pull outputs that terminate on identities to operands. for (auto &op : llvm::reverse(streamOp.body().front())) { if (isIdentityOp(&op)) { - auto result = op.getResult(0); auto operand = op.getOperand(0); + auto result = op.getResult(0); + if (!operand.getType().isa()) continue; + if (!result.getType().isa()) continue; if (bufferSet.rangeMap[result].buffer && !bufferSet.rangeMap[operand].buffer) { LLVM_DEBUG(llvm::dbgs() << " + PROPAGATE IDENTITY RESULT->OPERAND: " @@ -190,6 +192,8 @@ static void allocateTransientBuffers(IREE::Flow::ExStreamFragmentOp streamOp, if (isIdentityOp(&op)) { auto operand = op.getOperand(0); auto result = op.getResult(0); + if (!operand.getType().isa()) continue; + if (!result.getType().isa()) continue; if (bufferSet.rangeMap[operand].buffer && !bufferSet.rangeMap[result].buffer) { LLVM_DEBUG(llvm::dbgs() << " + PROPAGATE IDENTITY OPERAND->RESULT: " @@ -219,6 +223,7 @@ static void allocateTransientBuffers(IREE::Flow::ExStreamFragmentOp streamOp, if (isNoOp(&op) || isIdentityOp(&op)) continue; for (auto it : llvm::enumerate(op.getResults())) { auto result = it.value(); + if (!result.getType().isa()) continue; // If the result is an output buffer we can just use that directly. if (bufferSet.rangeMap[result].buffer) { LLVM_DEBUG(llvm::dbgs() << " -- SKIP ALREADY SET BUFFER RESULT(" @@ -519,6 +524,10 @@ static LogicalResult recordStreamCommands(Value device, Value commandBuffer, } else if (isNoOp(&op) || isIdentityOp(&op)) { // No work to perform. For identity ops, all buffers have been pushed // to "real" ops. + } else if (isa(op)) { + // HACK: all this code is going away soon. + auto newOp = rewriter.clone(op); + op.replaceAllUsesWith(newOp); } else { return op.emitOpError() << "unexpected in stream"; } @@ -532,8 +541,11 @@ class ExStreamFragmentOpConversion using OpConversionPattern< IREE::Flow::ExStreamFragmentOp>::OpConversionPattern; LogicalResult matchAndRewrite( - IREE::Flow::ExStreamFragmentOp streamOp, llvm::ArrayRef operands, + IREE::Flow::ExStreamFragmentOp streamOp, ArrayRef newOperands, ConversionPatternRewriter &rewriter) const override { + IREE::Flow::ExStreamFragmentOp::Adaptor adaptor( + newOperands, streamOp->getAttrDictionary()); + // TODO(benvanik): choose buffer mode/category based on stream commands. auto mode = IREE::HAL::CommandBufferModeBitfield::OneShot; auto category = IREE::HAL::CommandCategoryBitfield::Dispatch | @@ -551,13 +563,13 @@ class ExStreamFragmentOpConversion // Remap non-tensor operands (such as workloads). auto &entryBlock = streamOp.body().front(); - for (int i = 0; i < operands.size(); ++i) { + for (int i = 0; i < adaptor.operands().size(); ++i) { if (streamOp.getOperand(i).getType().isa()) { bufferSet.rangeMap[entryBlock.getArgument(i)] = - BufferRange{operands[i]}; + BufferRange{adaptor.operands()[i]}; } else { rewriter.replaceUsesOfBlockArgument(entryBlock.getArgument(i), - operands[i]); + adaptor.operands()[i]); } } @@ -591,10 +603,10 @@ class ExStreamFragmentOpConversion // It's annoying, but we need to do this replacement at the very end as // otherwise we lose access to the original values (which we need for // shape information). - for (int i = 0; i < operands.size(); ++i) { - if (operands[i].getType().isa()) { + for (int i = 0; i < adaptor.operands().size(); ++i) { + if (adaptor.operands()[i].getType().isa()) { rewriter.replaceUsesOfBlockArgument(entryBlock.getArgument(i), - operands[i]); + adaptor.operands()[i]); } } diff --git a/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/ConvertTensorOps.cpp b/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/ConvertTensorOps.cpp index b59988f8d44a..591a1b19d8ce 100644 --- a/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/ConvertTensorOps.cpp +++ b/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/ConvertTensorOps.cpp @@ -90,7 +90,8 @@ class TensorLoadOpConversion LogicalResult matchAndRewrite( IREE::Flow::TensorLoadOp loadOp, llvm::ArrayRef newOperands, ConversionPatternRewriter &rewriter) const override { - IREE::Flow::TensorLoadOp::Adaptor operands(newOperands); + IREE::Flow::TensorLoadOp::Adaptor operands(newOperands, + loadOp->getAttrDictionary()); auto source = IREE::HAL::TensorRewriteAdaptor::getChecked( loadOp.getLoc(), loadOp.source(), operands.source(), rewriter); if (!source.hasValue()) { @@ -117,7 +118,8 @@ class TensorStoreOpConversion LogicalResult matchAndRewrite( IREE::Flow::TensorStoreOp storeOp, llvm::ArrayRef newOperands, ConversionPatternRewriter &rewriter) const override { - IREE::Flow::TensorStoreOp::Adaptor operands(newOperands); + IREE::Flow::TensorStoreOp::Adaptor operands(newOperands, + storeOp->getAttrDictionary()); auto target = IREE::HAL::TensorRewriteAdaptor::getChecked( storeOp.getLoc(), storeOp.target(), operands.target(), rewriter); diff --git a/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/test/stream_ops.mlir b/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/test/stream_ops.mlir index abc7d9f8e8b8..3ba8d894906e 100644 --- a/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/test/stream_ops.mlir +++ b/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/test/stream_ops.mlir @@ -24,16 +24,17 @@ func @multipleDispatches(%arg0: tensor<128xf32>) -> tensor<128xf32> { // CHECK: %[[TMP_BUF:.+]] = hal.allocator.allocate {{.+}}, "DeviceVisible|DeviceLocal", "Transfer|Dispatch" // CHECK: %[[CMD:.+]] = hal.command_buffer.create {{.+}}, OneShot, "Transfer|Dispatch" // CHECK-NEXT: hal.command_buffer.begin %[[CMD]] - %0 = flow.ex.stream.fragment(%arg1 = %cst : index, %arg2 = %arg0 : tensor<128xf32>) -> tensor<128xf32> { + %0 = flow.ex.stream.fragment(%cst, %arg0) : (index, tensor<128xf32>) -> tensor<128xf32> = + (%arg1: index, %arg2: tensor<128xf32>) -> tensor<128xf32> { // CHECK-DAG: %[[EXE_LAYOUT:.+]] = hal.executable_layout.lookup // CHECK: hal.command_buffer.push_descriptor_set %[[CMD]], %[[EXE_LAYOUT]], set = %c0, bindings = [%c0 = (%arg0, %c0, %c512), %c1 = (%[[TMP_BUF]], %c0, %c512)] // CHECK: hal.command_buffer.dispatch.symbol {{.+}}, @ex0::@vmla::@entry0, workgroup_xyz // CHECK: hal.command_buffer.execution_barrier - %1 = flow.dispatch @ex0::@entry0[%arg1] (%arg2) : (tensor<128xf32>) -> tensor<128xf32> + %1 = flow.dispatch @ex0::@entry0[%arg1](%arg2) : (tensor<128xf32>) -> tensor<128xf32> // CHECK: hal.command_buffer.push_descriptor_set // CHECK: hal.command_buffer.dispatch.symbol {{.+}}, @ex0::@vmla::@entry0, workgroup_xyz // CHECK: hal.command_buffer.execution_barrier - %2 = flow.dispatch @ex0::@entry0[%arg1] (%1) : (tensor<128xf32>) -> tensor<128xf32> + %2 = flow.dispatch @ex0::@entry0[%arg1](%1) : (tensor<128xf32>) -> tensor<128xf32> flow.return %2 : tensor<128xf32> } // CHECK: hal.command_buffer.end %[[CMD]] @@ -52,7 +53,8 @@ func @tensorUpdate(%arg0 : tensor<1x1x10xf32>, %arg1 : tensor<5x1x10xf32>) -> te // CHECK: %[[RET_BUF:.+]] = hal.allocator.allocate // CHECK: %[[CMD:.+]] = hal.command_buffer.create // CHECK-NEXT: hal.command_buffer.begin %[[CMD]] - %0 = flow.ex.stream.fragment(%arg2 = %arg0 : tensor<1x1x10xf32>, %arg3 = %arg1 : tensor<5x1x10xf32>, %arg4 = %c4 : index, %arg5 = %c1 : index) -> tensor<5x1x10xf32> { + %0 = flow.ex.stream.fragment(%arg0, %arg1, %c4, %c1) : (tensor<1x1x10xf32>, tensor<5x1x10xf32>, index, index) -> tensor<5x1x10xf32> = + (%arg2: tensor<1x1x10xf32>, %arg3: tensor<5x1x10xf32>, %arg4: index, %arg5: index) -> tensor<5x1x10xf32> { // TODO(laurenzo): Update these checks to be more precise. The regexes can // match too much, masking issues. // CHECK-NEXT: hal.command_buffer.copy_buffer %[[CMD]], %[[TBUF]], %c0, %[[RET_BUF]], %c0, %c200 @@ -95,16 +97,12 @@ func @dispatchWithShapeTies(%arg0: tensor, %bs : index) -> tensor, %arg3 = %bs : index) -> tensor { - %1 = shapex.make_ranked_shape %arg3 : (index) -> !shapex.ranked_shape<[?,128]> - %2 = shapex.tie_shape %arg2, %1 : tensor, !shapex.ranked_shape<[?,128]> - %3 = flow.dispatch @ex0::@entry0[%arg1] (%2, %arg3) : (tensor, index) -> tensor - %4 = shapex.tie_shape %3, %1 : tensor, !shapex.ranked_shape<[?,128]> - %5 = flow.dispatch @ex0::@entry0[%arg1] (%4, %arg3) : (tensor, index) -> tensor - %6 = shapex.tie_shape %5, %1 : tensor, !shapex.ranked_shape<[?,128]> - %7 = flow.dispatch @ex0::@entry0[%arg1] (%6, %arg3) : (tensor, index) -> tensor - %8 = shapex.tie_shape %7, %1 : tensor, !shapex.ranked_shape<[?,128]> - flow.return %8 : tensor + %0 = flow.ex.stream.fragment(%cst, %arg0, %bs) : (index, tensor{%cst}, index) -> tensor{%cst} = + (%arg1: index, %arg2: tensor, %arg3: index) -> tensor { + %3 = flow.dispatch @ex0::@entry0[%arg1](%arg2, %arg3) : (tensor{%arg3}, index) -> tensor{%arg3} + %5 = flow.dispatch @ex0::@entry0[%arg1](%3, %arg3) : (tensor{%arg3}, index) -> tensor{%arg3} + %7 = flow.dispatch @ex0::@entry0[%arg1](%5, %arg3) : (tensor{%arg3}, index) -> tensor{%arg3} + flow.return %7 : tensor } return %0 : tensor } @@ -132,14 +130,11 @@ func @static_tiled_dispatch(%arg0: tensor<7x4x24xf32>) -> tensor<4x7x1024xf32> { %c512 = constant 512 : index // CHECK: %[[CMD:.+]] = hal.command_buffer.create {{.+}}, OneShot, "Transfer|Dispatch" // CHECK-NEXT: hal.command_buffer.begin %[[CMD]] - %1 = flow.ex.stream.fragment( - %arg3 = %arg0 : tensor<7x4x24xf32>, - %arg6 = %c1024 : index, - %arg7 = %c512 : index - ) -> tensor<4x7x1024xf32> { + %1 = flow.ex.stream.fragment(%arg0, %c1024, %c512) : (tensor<7x4x24xf32>, index, index) -> tensor<4x7x1024xf32> = + (%arg3: tensor<7x4x24xf32>, %arg6: index, %arg7: index) -> tensor<4x7x1024xf32> { // CHECK: hal.command_buffer.push_descriptor_set %[[CMD]], %executable_layout, set = %c0, bindings = [%c0 = (%arg0, %c0, %c2688), %c1 = (%buffer, %c0, %c114688)] // CHECK: hal.command_buffer.dispatch.symbol {{.+}}, @ex::@tgt::@entry, workgroup_xyz - %0 = flow.dispatch @ex::@entry[%arg6, %arg7, %arg7] (%arg3) : (tensor<7x4x24xf32>) -> tensor<4x7x1024xf32> + %0 = flow.dispatch @ex::@entry[%arg6, %arg7, %arg7](%arg3) : (tensor<7x4x24xf32>) -> tensor<4x7x1024xf32> flow.return %0 : tensor<4x7x1024xf32> } // CHECK: hal.command_buffer.end %[[CMD]] @@ -169,16 +164,8 @@ func @dynamic_tiled_dispatch(%arg0: tensor<7x?x24x?xf32>, %arg1: index, %arg2: i %c512 = constant 512 : index // CHECK: %[[CMD:.+]] = hal.command_buffer.create {{.+}}, OneShot, "Transfer|Dispatch" // CHECK-NEXT: hal.command_buffer.begin %[[CMD]] - %2 = flow.ex.stream.fragment( - %arg3 = %arg0 : tensor<7x?x24x?xf32>, - %arg4 = %arg1 : index, - %arg5 = %arg2 : index, - %arg6 = %c1024 : index, - %arg7 = %c512 : index - ) -> tensor { - %3 = shapex.make_ranked_shape %arg4, %arg5 : (index, index) -> !shapex.ranked_shape<[7,?,24,?]> - %4 = shapex.make_ranked_shape %arg5, %arg4 : (index, index) -> !shapex.ranked_shape<[?,?,1024]> - %5 = shapex.tie_shape %arg3, %3 : tensor<7x?x24x?xf32>, !shapex.ranked_shape<[7,?,24,?]> + %2 = flow.ex.stream.fragment(%arg0, %arg1, %arg2, %c1024, %c512) : (tensor<7x?x24x?xf32>{%arg1, %arg2}, index, index, index, index) -> tensor{%arg2, %arg1} = + (%arg3: tensor<7x?x24x?xf32>, %arg4: index, %arg5: index, %arg6: index, %arg7: index) -> tensor { // CHECK: hal.command_buffer.push_constants %[[CMD]], %executable_layout, offset = 0, values = [%{{.+}}, %{{.+}}, %{{.+}}, %{{.+}}] : i32 // CHECK: hal.command_buffer.push_descriptor_set %[[CMD]], %executable_layout, set = %c0, bindings = [%c0 = (%arg0, %c0, %9), %c1 = (%buffer, %c0, %12)] @@ -199,9 +186,8 @@ func @dynamic_tiled_dispatch(%arg0: tensor<7x?x24x?xf32>, %arg1: index, %arg2: i // CHECK: hal.command_buffer.dispatch.symbol %[[CMD_INNER]], @ex::@tgt::@entry, workgroup_xyz = // CHECK-SAME: [%[[COUNT_X]], %[[COUNT_Y]], %[[COUNT_Z]]] - %6 = flow.dispatch @ex::@entry[%arg6, %arg7, %arg7] (%5, %arg4, %arg5, %arg5, %arg4) : (tensor<7x?x24x?xf32>, index, index, index, index) -> tensor - %7 = shapex.tie_shape %6, %4 : tensor, !shapex.ranked_shape<[?,?,1024]> - flow.return %7 : tensor + %6 = flow.dispatch @ex::@entry[%arg6, %arg7, %arg7](%arg3, %arg4, %arg5, %arg5, %arg4) : (tensor<7x?x24x?xf32>{%arg4, %arg5}, index, index, index, index) -> tensor{%arg5, %arg4} + flow.return %6 : tensor } // CHECK: hal.command_buffer.end %[[CMD]] return %2 : tensor diff --git a/iree/compiler/Dialect/HAL/Target/CUDA/test/smoketest.mlir b/iree/compiler/Dialect/HAL/Target/CUDA/test/smoketest.mlir index ceb0560dd88d..de75c9d1f4cc 100644 --- a/iree/compiler/Dialect/HAL/Target/CUDA/test/smoketest.mlir +++ b/iree/compiler/Dialect/HAL/Target/CUDA/test/smoketest.mlir @@ -1,33 +1,24 @@ // RUN: iree-opt -split-input-file -iree-hal-transformation-pipeline -iree-hal-target-backends=cuda %s | IreeFileCheck %s - #map = affine_map<(d0) -> (d0)> -module { - flow.executable @add_dispatch_0 attributes {sym_visibility = "private"} { - flow.dispatch.entry @add_dispatch_0 attributes {signature = (tensor<16xf32>, tensor<16xf32>) -> tensor<16xf32>, workgroup_rank = 3 : index} - module { - func @add_dispatch_0(%arg0: !flow.dispatch.input<16xf32>, %arg1: !flow.dispatch.input<16xf32>, %arg2: !flow.dispatch.output<16xf32>) { - %0 = linalg.init_tensor [16] : tensor<16xf32> - %1 = flow.dispatch.input.load %arg0 : !flow.dispatch.input<16xf32> -> tensor<16xf32> - %2 = flow.dispatch.input.load %arg1 : !flow.dispatch.input<16xf32> -> tensor<16xf32> - %3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%1, %2 : tensor<16xf32>, tensor<16xf32>) outs(%0 : tensor<16xf32>) { - ^bb0(%arg3: f32, %arg4: f32, %arg5: f32): // no predecessors - %4 = addf %arg3, %arg4 : f32 - linalg.yield %4 : f32 - } -> tensor<16xf32> - flow.dispatch.output.store %3, %arg2 : tensor<16xf32> -> !flow.dispatch.output<16xf32> - return - } - } +flow.executable @add_dispatch_0 { + flow.dispatch.entry @add_dispatch_0 attributes { + signature = (tensor<16xf32>, tensor<16xf32>) -> tensor<16xf32>, + workgroup_rank = 3 : index } - func @add(%arg0: tensor<16xf32>, %arg1: tensor<16xf32>) -> tensor<16xf32> attributes {iree.module.export, iree.reflection = {f = "I13!B4!d16B4!d16R7!B4!d16", fv = "1"}} { - %c1 = constant 1 : index - %c16 = constant 16 : index - %0 = flow.ex.stream.fragment(%arg2 = %c16 : index, %arg3 = %c1 : index, %arg4 = %arg0 : tensor<16xf32>, %arg5 = %arg1 : tensor<16xf32>) -> tensor<16xf32> { - %1 = flow.dispatch @add_dispatch_0::@add_dispatch_0[%arg2, %arg3, %arg3] (%arg4, %arg5) : (tensor<16xf32>, tensor<16xf32>) -> tensor<16xf32> - flow.return %1 : tensor<16xf32> + module { + func @add_dispatch_0(%arg0: !flow.dispatch.input<16xf32>, %arg1: !flow.dispatch.input<16xf32>, %arg2: !flow.dispatch.output<16xf32>) { + %0 = linalg.init_tensor [16] : tensor<16xf32> + %1 = flow.dispatch.input.load %arg0 : !flow.dispatch.input<16xf32> -> tensor<16xf32> + %2 = flow.dispatch.input.load %arg1 : !flow.dispatch.input<16xf32> -> tensor<16xf32> + %3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%1, %2 : tensor<16xf32>, tensor<16xf32>) outs(%0 : tensor<16xf32>) { + ^bb0(%arg3: f32, %arg4: f32, %arg5: f32): // no predecessors + %4 = addf %arg3, %arg4 : f32 + linalg.yield %4 : f32 + } -> tensor<16xf32> + flow.dispatch.output.store %3, %arg2 : tensor<16xf32> -> !flow.dispatch.output<16xf32> + return } - return %0 : tensor<16xf32> } } diff --git a/iree/compiler/Dialect/HAL/Target/VMLA/test/i1_types.mlir b/iree/compiler/Dialect/HAL/Target/VMLA/test/i1_types.mlir index df5ed384dd7f..16e8a4583354 100644 --- a/iree/compiler/Dialect/HAL/Target/VMLA/test/i1_types.mlir +++ b/iree/compiler/Dialect/HAL/Target/VMLA/test/i1_types.mlir @@ -5,8 +5,9 @@ func @i1_op_usage(%arg0: tensor<4xi1>) -> tensor<4xi1> { %c4 = constant 4 : index // CHECK: %0 = iree.byte_buffer.constant : !iree.byte_buffer = dense<[1, 0, 1, 0]> : tensor<4xi8> %cst = constant dense<[true, false, true, false]> : tensor<4xi1> - %0 = flow.ex.stream.fragment(%arg1 = %c4 : index, %arg2 = %arg0 : tensor<4xi1>, %arg3 = %cst : tensor<4xi1>) -> tensor<4xi1> { - %1 = flow.dispatch @i1_op_usage_ex_dispatch_0::@i1_op_usage_ex_dispatch_0[%arg1] (%arg2, %arg3) : (tensor<4xi1>, tensor<4xi1>) -> tensor<4xi1> + %0 = flow.ex.stream.fragment(%c4, %arg0, %cst) : (index, tensor<4xi1>, tensor<4xi1>) -> (tensor<4xi1>) = + (%arg1: index, %arg2: tensor<4xi1>, %arg3: tensor<4xi1>) -> (tensor<4xi1>) { + %1 = flow.dispatch @i1_op_usage_ex_dispatch_0::@i1_op_usage_ex_dispatch_0[%arg1](%arg2, %arg3) : (tensor<4xi1>, tensor<4xi1>) -> (tensor<4xi1>) flow.return %1 : tensor<4xi1> } return %0 : tensor<4xi1> diff --git a/iree/compiler/Dialect/Shape/IR/ShapeInterfaces.td b/iree/compiler/Dialect/Shape/IR/ShapeInterfaces.td index d9df80db260d..c14a11961f6f 100644 --- a/iree/compiler/Dialect/Shape/IR/ShapeInterfaces.td +++ b/iree/compiler/Dialect/Shape/IR/ShapeInterfaces.td @@ -30,25 +30,27 @@ def Shape_ShapeCarryingOpInterface : OpInterface<"ShapeCarryingInterface"> { let methods = [ StaticInterfaceMethod< - [{Returns a RankedShape for the given shaped result value.}], - "Value", "buildResultValueRankedShape", - (ins "Value":$result, "OpBuilder &":$builder), [{ + /*desc=*/[{Returns a RankedShape for the given shaped result value.}], + /*retTy=*/"Value", + /*methodName=*/"buildResultValueRankedShape", + /*args=*/(ins "Value":$result, "OpBuilder &":$builder), + /*methodBody=*/[{ auto carryingOp = dyn_cast(result.getDefiningOp()); - for (unsigned idx = 0; idx < carryingOp->getNumResults(); ++idx) { - if (carryingOp->getResult(idx) == result) { - return carryingOp.buildResultRankedShape(idx, builder); - } - } - llvm_unreachable("called on result not from a shape-aware-op"); + return carryingOp.buildResultRankedShape( + result.cast().getResultNumber(), builder); }] >, InterfaceMethod< - [{Returns a RankedShape for the given shaped operand index.}], - "Value", "buildOperandRankedShape", (ins "unsigned":$idx, "OpBuilder &":$builder) + /*desc=*/[{Returns a RankedShape for the given shaped operand index.}], + /*retTy=*/"Value", + /*methodName=*/"buildOperandRankedShape", + /*args=*/(ins "unsigned":$idx, "OpBuilder &":$builder) >, InterfaceMethod< - [{Returns a RankedShape for the given shaped result index.}], - "Value", "buildResultRankedShape", (ins "unsigned":$idx, "OpBuilder &":$builder) + /*desc=*/[{Returns a RankedShape for the given shaped result index.}], + /*retTy=*/"Value", + /*methodName=*/"buildResultRankedShape", + /*args=*/(ins "unsigned":$idx, "OpBuilder &":$builder) >, ]; } diff --git a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertHLOToVMLA.cpp b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertHLOToVMLA.cpp index c32b9f6ca870..52b35432d0c3 100644 --- a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertHLOToVMLA.cpp +++ b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertHLOToVMLA.cpp @@ -598,7 +598,9 @@ struct DynamicSliceOpConversion }; struct CompareOpConversion : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; + CompareOpConversion(TypeConverter &typeConverter, MLIRContext *context) + : OpConversionPattern(typeConverter, context, + /*benefit=*/9999) {} LogicalResult matchAndRewrite( mhlo::CompareOp srcOp, ArrayRef rawOperands, diff --git a/iree/test/e2e/hackability/flow_partitioned.mlir b/iree/test/e2e/hackability/flow_partitioned.mlir index d9b99f4a0871..7c5513e2f7ce 100644 --- a/iree/test/e2e/hackability/flow_partitioned.mlir +++ b/iree/test/e2e/hackability/flow_partitioned.mlir @@ -15,7 +15,7 @@ flow.executable @ex0 { func @staticShapedFn() -> tensor<4xf32> { %input = iree.unfoldable_constant dense<[-1.0, 2.0, -3.0, 4.0]> : tensor<4xf32> %workload = constant 4 : index - %0 = flow.dispatch @ex0::@dispatch0[%workload] (%input) : (tensor<4xf32>) -> tensor<4xf32> + %0 = flow.dispatch @ex0::@dispatch0[%workload](%input) : (tensor<4xf32>) -> tensor<4xf32> return %0 : tensor<4xf32> } // CHECK: 4xf32=-2 4 -6 8 From d34bdc29d00fa346eb4e4ecf285a946c870f3593 Mon Sep 17 00:00:00 2001 From: Ben Vanik Date: Wed, 10 Mar 2021 17:10:54 -0800 Subject: [PATCH 07/17] Removing the RematerializeDispatchConstants pass. The ClosureOpInterface canonicalizer does this now. --- iree/compiler/Dialect/Flow/Transforms/BUILD | 1 - .../Dialect/Flow/Transforms/CMakeLists.txt | 1 - .../Dialect/Flow/Transforms/Passes.cpp | 6 - .../compiler/Dialect/Flow/Transforms/Passes.h | 5 - .../RematerializeDispatchConstants.cpp | 274 ------------------ .../rematerialize_dispatch_constants.mlir | 154 ---------- 6 files changed, 441 deletions(-) delete mode 100644 iree/compiler/Dialect/Flow/Transforms/RematerializeDispatchConstants.cpp delete mode 100644 iree/compiler/Dialect/Flow/Transforms/test/rematerialize_dispatch_constants.mlir diff --git a/iree/compiler/Dialect/Flow/Transforms/BUILD b/iree/compiler/Dialect/Flow/Transforms/BUILD index fe910da963bd..720ec1de2abf 100644 --- a/iree/compiler/Dialect/Flow/Transforms/BUILD +++ b/iree/compiler/Dialect/Flow/Transforms/BUILD @@ -46,7 +46,6 @@ cc_library( "OutlineLargeConstants.cpp", "Passes.cpp", "PrePostPartitioningConversion.cpp", - "RematerializeDispatchConstants.cpp", "StripAndSplatConstantVariables.cpp", ], hdrs = [ diff --git a/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt b/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt index 61a31fa3c015..97c556d81ef1 100644 --- a/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt +++ b/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt @@ -41,7 +41,6 @@ iree_cc_library( "OutlineLargeConstants.cpp" "Passes.cpp" "PrePostPartitioningConversion.cpp" - "RematerializeDispatchConstants.cpp" "StripAndSplatConstantVariables.cpp" DEPS LLVMSupport diff --git a/iree/compiler/Dialect/Flow/Transforms/Passes.cpp b/iree/compiler/Dialect/Flow/Transforms/Passes.cpp index 50204c445600..0e56a132d9cc 100644 --- a/iree/compiler/Dialect/Flow/Transforms/Passes.cpp +++ b/iree/compiler/Dialect/Flow/Transforms/Passes.cpp @@ -210,12 +210,6 @@ void buildFlowTransformPassPipeline(OpPassManager &passManager) { passManager.addNestedPass( IREE::Flow::createFoldCompatibleDispatchRegionsPass()); - // Note that as we are rematerializing things here it's critical we do not run - // the canonicalizer/CSE between now and when we outline - otherwise it'll - // undo all of our work! - passManager.addNestedPass( - IREE::Flow::createRematerializeDispatchConstantsPass()); - // Outline the dispatch regions into their own functions wrapped in // executables. This separates sequencer functions performing dispatches from // dispatchees. diff --git a/iree/compiler/Dialect/Flow/Transforms/Passes.h b/iree/compiler/Dialect/Flow/Transforms/Passes.h index e7101f7cdeba..0d2c28b1e8b4 100644 --- a/iree/compiler/Dialect/Flow/Transforms/Passes.h +++ b/iree/compiler/Dialect/Flow/Transforms/Passes.h @@ -124,10 +124,6 @@ std::unique_ptr> createIdentifyDispatchRegions2Pass(); std::unique_ptr> createFoldCompatibleDispatchRegionsPass(); -// Rematerializes small previously-CSE'd constants into dispatch regions. -std::unique_ptr> -createRematerializeDispatchConstantsPass(); - // Outlines dispatch regions into executables. std::unique_ptr> createOutlineDispatchRegionsPass(); std::unique_ptr> createOutlineDispatchRegions2Pass(); @@ -200,7 +196,6 @@ inline void registerFlowPasses() { createIdentifyDispatchRegionsPass(); createIdentifyDispatchRegions2Pass(); createFoldCompatibleDispatchRegionsPass(); - createRematerializeDispatchConstantsPass(); createOutlineDispatchRegionsPass(); createCreateBenchmarkFuncs(); createOutlineLargeConstantsPass(); diff --git a/iree/compiler/Dialect/Flow/Transforms/RematerializeDispatchConstants.cpp b/iree/compiler/Dialect/Flow/Transforms/RematerializeDispatchConstants.cpp deleted file mode 100644 index 4f82b470bf32..000000000000 --- a/iree/compiler/Dialect/Flow/Transforms/RematerializeDispatchConstants.cpp +++ /dev/null @@ -1,274 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include - -#include "iree/compiler/Dialect/Flow/IR/FlowOps.h" -#include "iree/compiler/Dialect/Flow/Transforms/Passes.h" -#include "llvm/Support/Debug.h" -#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" -#include "mlir/Dialect/StandardOps/IR/Ops.h" -#include "mlir/IR/Attributes.h" -#include "mlir/IR/BlockAndValueMapping.h" -#include "mlir/IR/Builders.h" -#include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/Location.h" -#include "mlir/IR/MLIRContext.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Pass/PassRegistry.h" -#include "mlir/Support/LLVM.h" -#include "mlir/Support/LogicalResult.h" -#include "mlir/Transforms/Utils.h" - -namespace mlir { -namespace iree_compiler { -namespace IREE { -namespace Flow { - -namespace { - -// Returns true if the constant value is a splat constant and can be -// rematerialized in a dispatch region. -bool isSplatConstant(ConstantOp constantOp) { - if (constantOp.getValue().isa()) { - // Splats are always small and can be much better handled by broadcasting - // within the dispatch regions. - return true; - } else if (auto value = constantOp.getValue().dyn_cast()) { - return value.isSplat(); - } else if (constantOp.getType().isIntOrFloat()) { - return true; - } - - // Assume anything unshaped is small. This may not always be true in custom - // dialects but is in std for now. - return false; -} - -// Returns true if the dispatch region is allowed to have constants inside. -// Certain regions that may get replaced or turned into kernel imports shouldn't -// have the constants moved into them as they'll just get lost. -bool canDispatchRegionContainConstants(DispatchRegionOp dispatchRegionOp) { - for (auto &block : dispatchRegionOp.body()) { - for (auto &op : block) { - // TODO(b/144530470): replace with tablegen attributes/interfaces. - if (isa(&op) || isa(&op)) { - // These two generally result in a lot of generated code so we try to - // keep constants out such that can dedupe more. We may still want to - // allow some parameters in (shapes/etc). - return false; - } - } - } - return true; -} - -// Recursively clones the given |sourceOp| and returns the newly cloned op. -Operation *recursivelyCloneOp(Operation *sourceOp, OpBuilder &builder, - BlockAndValueMapping *mapping) { - // Note that we dedupe required operands in the case of multiple arguments - // coming from the same source operation. - SmallPtrSet operandOps; - for (auto operand : sourceOp->getOperands()) { - operandOps.insert(operand.getDefiningOp()); - } - for (auto *operandOp : operandOps) { - recursivelyCloneOp(operandOp, builder, mapping); - } - return builder.clone(*sourceOp, *mapping); -} - -// Clones the |sourceValue| op tree into |targetBlock|. -// |mapping| is used to lookup existing values that may be present in the block -// such as block arguments or already cloned ancestor ops. |mapping| will be -// updated as the tree is cloned. -Value cloneOpTreeIntoBlock(Value sourceValue, Block *targetBlock, - BlockAndValueMapping *mapping) { - // If the op has already been cloned we can just reuse that. - // This happens if multiple arguments reference the same trees. - if (auto existingValue = mapping->lookupOrNull(sourceValue)) { - return existingValue; - } - - OpBuilder builder = OpBuilder::atBlockEnd(targetBlock); - builder.setInsertionPointToStart(targetBlock); - auto *sourceOp = sourceValue.getDefiningOp(); - auto *clonedOp = recursivelyCloneOp(sourceOp, builder, mapping); - - // Return only the result matching our source value (in the case of multiple - // results). - int resultIndex = std::distance( - sourceOp->result_begin(), - std::find(sourceOp->result_begin(), sourceOp->result_end(), sourceValue)); - return clonedOp->getResult(resultIndex); -} - -// Modify the second operand of the SegmentSize attribute -// TODO(ataei): Remove this once we have flow.dispatch.workgroups only here. -template -void dropOperands(DispatchOpType dispatchOp, ArrayRef deadIndices); - -template <> -void dropOperands(DispatchRegionOp dispatchRegionOp, - ArrayRef deadIndices) { - for (unsigned argIndex : llvm::reverse(deadIndices)) { - dispatchRegionOp.argsMutable().erase(argIndex); - } -} - -template <> -void dropOperands(DispatchWorkgroupsOp dispatchWorkgroupsOp, - ArrayRef deadIndices) { - for (unsigned argIndex : llvm::reverse(deadIndices)) { - dispatchWorkgroupsOp.operandsMutable().erase(argIndex); - } -} - -// Inlines use of the given |value| from outside of a dispatch region to inside -// of it and removes the argument. Supports multiple arguments that reference -// |value| and will clone the entire value tree. -template -LogicalResult inlineDispatchRegionOperandsUsingValue(DispatchOpType dispatchOp, - ValueRange args, - Value value) { - // Find all args that are using this value. - SmallVector argIndices; - for (auto arg : llvm::enumerate(args)) { - if (arg.value() == value) { - argIndices.push_back(arg.index()); - } - } - if (argIndices.empty()) { - // Not used? Wasteful call! - return success(); - } - - // Clone the value (and the ops required to create it) into the entry block. - auto &entryBlock = dispatchOp.body().getBlocks().front(); - BlockAndValueMapping mapping; - auto clonedValue = cloneOpTreeIntoBlock(value, &entryBlock, &mapping); - - // Replace all uses of the inner operand with the new value. - for (unsigned argIndex : argIndices) { - entryBlock.getArgument(argIndex).replaceAllUsesWith(clonedValue); - } - - // Remove the dispatch region args and the block args that have been - // replaced. - for (unsigned argIndex : llvm::reverse(argIndices)) { - entryBlock.eraseArgument(argIndex); - } - dropOperands(dispatchOp, argIndices); - - return success(); -} - -// Rematerializes a constant inside of all dispatch regions that use it. -// Afterward the constant is only removed if there are no other uses within -// the non-dispatch block (such as by sequencer ops). -LogicalResult rematerializeConstantInDispatchRegions(ConstantOp constantOp) { - Value constantValue = constantOp.getResult(); - SmallVector usingRegionOps; - for (auto *user : constantValue.getUsers()) { - if (auto dispatchRegionOp = dyn_cast(user)) { - // Ensure this isn't just the workload and is used as an arg. - if (std::find(dispatchRegionOp.args().begin(), - dispatchRegionOp.args().end(), - constantValue) != dispatchRegionOp.args().end()) { - if (canDispatchRegionContainConstants(dispatchRegionOp)) { - usingRegionOps.push_back(dispatchRegionOp); - } - } - } - } - for (auto &dispatchRegionOp : usingRegionOps) { - if (failed(inlineDispatchRegionOperandsUsingValue( - dispatchRegionOp, dispatchRegionOp.args(), constantValue))) { - return failure(); - } - } - return success(); -} - -LogicalResult rematerializeConstantInDispatchWorkgroupsRegions( - ConstantOp constantOp) { - Value constantValue = constantOp.getResult(); - for (auto *user : constantValue.getUsers()) { - if (auto dispatchWorkgroupsOp = dyn_cast(user)) { - if (failed(inlineDispatchRegionOperandsUsingValue( - dispatchWorkgroupsOp, dispatchWorkgroupsOp.operands(), - constantValue))) { - return failure(); - } - } - } - return success(); -} - -} // namespace - -// Finds constant arguments to dispatch regions that are too small to be worth -// putting into constant pools. This prevents things like a CSE'd scalar -// constant of 0.0 being passed by reference to a bunch of regions. Later -// backend-specific passes running on the dispatch regions may also be able to -// improve their constant propagation chances by having the full constant value -// available. -// -// Note that this currently only operates at the block level. Constants that are -// pushed across branches are assumed to have been rematerialized within blocks -// already, but if that isn't the case then this pass can be extended to do -// that. -class RematerializeDispatchConstantsPass - : public PassWrapper { - public: - void runOnFunction() override { - for (auto &block : getFunction()) { - SmallVector smallConstantOps; - for (auto constantOp : block.getOps()) { - if (isSplatConstant(constantOp)) { - smallConstantOps.push_back(constantOp); - } - } - // Note: we iterate in reverse so that the rematerialized constants appear - // in the same order they did originally (as insertion is at the top). - for (auto constantOp : llvm::reverse(smallConstantOps)) { - if (failed(rematerializeConstantInDispatchRegions(constantOp))) { - return signalPassFailure(); - } - if (failed( - rematerializeConstantInDispatchWorkgroupsRegions(constantOp))) { - return signalPassFailure(); - } - // Remove if there are no other uses within the block. - if (constantOp.use_empty()) { - constantOp.erase(); - } - } - } - } -}; - -std::unique_ptr> -createRematerializeDispatchConstantsPass() { - return std::make_unique(); -} - -static PassRegistration pass( - "iree-flow-rematerialize-dispatch-constants", - "Rematerializes small previously-CSE'd constants into dispatch regions"); - -} // namespace Flow -} // namespace IREE -} // namespace iree_compiler -} // namespace mlir diff --git a/iree/compiler/Dialect/Flow/Transforms/test/rematerialize_dispatch_constants.mlir b/iree/compiler/Dialect/Flow/Transforms/test/rematerialize_dispatch_constants.mlir deleted file mode 100644 index d2f434a1429b..000000000000 --- a/iree/compiler/Dialect/Flow/Transforms/test/rematerialize_dispatch_constants.mlir +++ /dev/null @@ -1,154 +0,0 @@ -// RUN: iree-opt -split-input-file -iree-flow-rematerialize-dispatch-constants %s | IreeFileCheck %s - -// CHECK-LABEL: func @rematerializeSmall -func @rematerializeSmall(%arg0 : tensor<4x4xf32>) -> tensor<4x4xf32> { - // CHECK: %[[WORKLOAD0:.+]] = constant 16 : index - %cst = constant 16 : index - %small = constant dense<1.23> : tensor<4x4xf32> - // CHECK: %[[R0:.+]] = flow.dispatch.region[%[[WORKLOAD0]] : index](%arg1 = %arg0 : tensor<4x4xf32>) -> (tensor<4x4xf32>) { - %0 = flow.dispatch.region[%cst : index](%arg1 = %arg0 : tensor<4x4xf32>, %arg2 = %small : tensor<4x4xf32>) -> tensor<4x4xf32> { - // CHECK-NEXT: %[[REMAT_SMALL:.+]] = constant dense<1.230000e+00> : tensor<4x4xf32> - // CHECK-NEXT: %1 = mhlo.add %arg1, %[[REMAT_SMALL]] : tensor<4x4xf32> - %3 = mhlo.add %arg1, %arg2 : tensor<4x4xf32> - flow.return %3 : tensor<4x4xf32> - } - return %0 : tensor<4x4xf32> -} - -// ----- - -// CHECK-LABEL: func @rematerializeSplat -func @rematerializeSplat(%arg0 : tensor<1025xi8>) -> tensor<1025xi8> { - // CHECK-DAG: %[[WORKLOAD0:.+]] = constant 16 : index - %cst = constant 16 : index - %large = constant dense<8> : tensor<1025xi8> - // CHECK-NEXT: %[[R0:.+]] = flow.dispatch.region[%[[WORKLOAD0]] : index](%arg1 = %arg0 : tensor<1025xi8>) -> (tensor<1025xi8>) { - %0 = flow.dispatch.region[%cst : index](%arg1 = %arg0 : tensor<1025xi8>, %arg2 = %large : tensor<1025xi8>) -> tensor<1025xi8> { - // CHECK-NEXT: %[[REMAT_SPLAT:.+]] = constant dense<8> : tensor<1025xi8> - // CHECK-NEXT: %1 = mhlo.add %arg1, %[[REMAT_SPLAT]] : tensor<1025xi8> - %3 = mhlo.add %arg1, %arg2 : tensor<1025xi8> - flow.return %3 : tensor<1025xi8> - } - return %0 : tensor<1025xi8> -} - -// ----- - -// CHECK-LABEL: func @noRematerializeLarge -func @noRematerializeLarge(%arg0 : tensor<1025xi8>) -> tensor<1025xi8> { - // CHECK-DAG: %[[WORKLOAD0:.+]] = constant 16 : index - // CHECK-DAG: %[[CST:.+]] = constant dense<{{.+}}> : tensor<1025xi8> - %cst = constant 16 : index - %large = constant dense<[0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63,64,65,66,67,68,69,70,71,72,73,74,75,76,77,78,79,80,81,82,83,84,85,86,87,88,89,90,91,92,93,94,95,96,97,98,99,100,101,102,103,104,105,106,107,108,109,110,111,112,113,114,115,116,117,118,119,120,121,122,123,124,125,126,127,128,129,130,131,132,133,134,135,136,137,138,139,140,141,142,143,144,145,146,147,148,149,150,151,152,153,154,155,156,157,158,159,160,161,162,163,164,165,166,167,168,169,170,171,172,173,174,175,176,177,178,179,180,181,182,183,184,185,186,187,188,189,190,191,192,193,194,195,196,197,198,199,200,201,202,203,204,205,206,207,208,209,210,211,212,213,214,215,216,217,218,219,220,221,222,223,224,225,226,227,228,229,230,231,232,233,234,235,236,237,238,239,240,241,242,243,244,245,246,247,248,249,250,251,252,253,254,255,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63,64,65,66,67,68,69,70,71,72,73,74,75,76,77,78,79,80,81,82,83,84,85,86,87,88,89,90,91,92,93,94,95,96,97,98,99,100,101,102,103,104,105,106,107,108,109,110,111,112,113,114,115,116,117,118,119,120,121,122,123,124,125,126,127,128,129,130,131,132,133,134,135,136,137,138,139,140,141,142,143,144,145,146,147,148,149,150,151,152,153,154,155,156,157,158,159,160,161,162,163,164,165,166,167,168,169,170,171,172,173,174,175,176,177,178,179,180,181,182,183,184,185,186,187,188,189,190,191,192,193,194,195,196,197,198,199,200,201,202,203,204,205,206,207,208,209,210,211,212,213,214,215,216,217,218,219,220,221,222,223,224,225,226,227,228,229,230,231,232,233,234,235,236,237,238,239,240,241,242,243,244,245,246,247,248,249,250,251,252,253,254,255,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63,64,65,66,67,68,69,70,71,72,73,74,75,76,77,78,79,80,81,82,83,84,85,86,87,88,89,90,91,92,93,94,95,96,97,98,99,100,101,102,103,104,105,106,107,108,109,110,111,112,113,114,115,116,117,118,119,120,121,122,123,124,125,126,127,128,129,130,131,132,133,134,135,136,137,138,139,140,141,142,143,144,145,146,147,148,149,150,151,152,153,154,155,156,157,158,159,160,161,162,163,164,165,166,167,168,169,170,171,172,173,174,175,176,177,178,179,180,181,182,183,184,185,186,187,188,189,190,191,192,193,194,195,196,197,198,199,200,201,202,203,204,205,206,207,208,209,210,211,212,213,214,215,216,217,218,219,220,221,222,223,224,225,226,227,228,229,230,231,232,233,234,235,236,237,238,239,240,241,242,243,244,245,246,247,248,249,250,251,252,253,254,255,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63,64,65,66,67,68,69,70,71,72,73,74,75,76,77,78,79,80,81,82,83,84,85,86,87,88,89,90,91,92,93,94,95,96,97,98,99,100,101,102,103,104,105,106,107,108,109,110,111,112,113,114,115,116,117,118,119,120,121,122,123,124,125,126,127,128,129,130,131,132,133,134,135,136,137,138,139,140,141,142,143,144,145,146,147,148,149,150,151,152,153,154,155,156,157,158,159,160,161,162,163,164,165,166,167,168,169,170,171,172,173,174,175,176,177,178,179,180,181,182,183,184,185,186,187,188,189,190,191,192,193,194,195,196,197,198,199,200,201,202,203,204,205,206,207,208,209,210,211,212,213,214,215,216,217,218,219,220,221,222,223,224,225,226,227,228,229,230,231,232,233,234,235,236,237,238,239,240,241,242,243,244,245,246,247,248,249,250,251,252,253,254,255,0]> : tensor<1025xi8> - // CHECK-NEXT: %[[R0:.+]] = flow.dispatch.region[%[[WORKLOAD0]] : index](%arg1 = %arg0 : tensor<1025xi8>, %arg2 = %[[CST]] : tensor<1025xi8>) -> (tensor<1025xi8>) { - %0 = flow.dispatch.region[%cst : index](%arg1 = %arg0 : tensor<1025xi8>, %arg2 = %large : tensor<1025xi8>) -> tensor<1025xi8> { - // CHECK-NEXT: %1 = mhlo.add %arg1, %arg2 : tensor<1025xi8> - %3 = mhlo.add %arg1, %arg2 : tensor<1025xi8> - flow.return %3 : tensor<1025xi8> - } - return %0 : tensor<1025xi8> -} - -// ----- - -// CHECK-LABEL: func @noRematerializeIntoDot -func @noRematerializeIntoDot(%arg0 : tensor<4x4xf32>) -> tensor<4x4xf32> { - // CHECK-DAG: %[[WORKLOAD0:.+]] = constant 16 : index - // CHECK-DAG: %[[SMALL:.+]] = constant dense<1.230000e+00> : tensor<4x4xf32> - %cst = constant 16 : index - %small = constant dense<1.23> : tensor<4x4xf32> - // CHECK-NEXT: %[[R0:.+]] = flow.dispatch.region[%[[WORKLOAD0]] : index](%arg1 = %arg0 : tensor<4x4xf32>, %arg2 = %[[SMALL]] : tensor<4x4xf32>) -> (tensor<4x4xf32>) { - %0 = flow.dispatch.region[%cst : index](%arg1 = %arg0 : tensor<4x4xf32>, %arg2 = %small : tensor<4x4xf32>) -> tensor<4x4xf32> { - // CHECK-NEXT: %1 = "mhlo.dot"(%arg1, %arg2) : (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> - %3 = "mhlo.dot"(%arg1, %arg2) : (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> - flow.return %3 : tensor<4x4xf32> - } - return %0 : tensor<4x4xf32> -} - -// ----- - -func @constant_capture(%arg0: tensor<10x20xf32>) -> tensor<10x20xf32> { - %c200 = constant 200 : index - %cst = constant 1.000000e+00 : f32 - %cst_0 = constant dense<2.000000e+00> : tensor<10x20xf32> - %cst_1 = constant dense< - [1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, - 6.000000e+00, 7.000000e+00, 8.000000e+00, 9.000000e+00, 1.000000e+01]> - : tensor<10xf32> - %0 = flow.dispatch.region[%c200 : index] - (%arg1 = %arg0 : tensor<10x20xf32>, %arg2 = %cst_0 : tensor<10x20xf32>, - %arg3 = %cst_1 : tensor<10xf32>, %arg4 = %cst : f32) -> tensor<10x20xf32> { - %1 = linalg.init_tensor [10, 20] : tensor<10x20xf32> - %2 = linalg.generic - {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, - affine_map<(d0, d1) -> (d0, d1)>, - affine_map<(d0, d1) -> (d0)>, - affine_map<(d0, d1) -> (d0, d1)>], - iterator_types = ["parallel", "parallel"]} - ins(%arg1, %arg2, %arg3 - : tensor<10x20xf32>, tensor<10x20xf32>, tensor<10xf32>) - outs(%1 : tensor<10x20xf32>) { - ^bb0(%arg5: f32, %arg6: f32, %arg7: f32, %arg8: f32): // no predecessors - %3 = addf %arg5, %arg4 : f32 - %4 = mulf %3, %arg6 : f32 - %5 = addf %4, %arg7 : f32 - linalg.yield %5 : f32 - } -> tensor<10x20xf32> - flow.return %2 : tensor<10x20xf32> - } - return %0 : tensor<10x20xf32> -} - -// CHECK-LABEL: func @constant_capture -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<10x20xf32> -// CHECK: %[[CST:.+]] = constant dense<[1.000000e+00, 2.000000e+00, -// CHECK-SAME: 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00, -// CHECK-SAME: 7.000000e+00, 8.000000e+00, 9.000000e+00, 1.000000e+01]> -// CHECK: flow.dispatch.region -// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]] = %[[ARG0]] -// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]] = %[[CST]] -// CHECK-DAG: %[[CST_0:.+]] = constant 1.000000e+00 : f32 -// CHECK-DAG: %[[CST_1:.+]] = constant dense<2.000000e+00> : tensor<10x20xf32> -// CHECK: %[[RESULT:.+]] = linalg.generic -// CHECK-SAME: ins(%[[ARG1]], %[[CST_1]], %[[ARG2]] -// CHECK-SAME: ) { -// CHECK: ^{{[a-zA-Z0-9_]+}}( -// CHECK-SAME: %[[ARG3:.[a-zA-Z0-9_]+]]: f32, -// CHECK-SAME: %[[ARG4:.[a-zA-Z0-9_]+]]: f32, -// CHECK-SAME: %[[ARG5:.[a-zA-Z0-9_]+]]: f32, -// CHECK-SAME: %[[ARG6:.[a-zA-Z0-9_]+]]: f32) -// CHECK: %[[T0:.+]] = addf %[[ARG3]], %[[CST_0]] -// CHECK: %[[T1:.+]] = mulf %[[T0]], %[[ARG4]] -// CHECK: %[[T2:.+]] = addf %[[T1]], %[[ARG5]] -// CHECK: linalg.yield %[[T2]] -// CHECK: } -// CHECK: flow.return %[[RESULT]] - -// ----- - -func @rematerialize_dispatch_workgroups(%arg0: tensor<8x8xf32>, %arg1: tensor<8x8xf32>) -> tensor<8x8xf32> { - %cst_0 = constant 0.0 : f32 - %c2 = constant 1 : index - %0 = flow.dispatch.workgroups[%c2, %c2, %c2](%cst_0, %arg0, %arg1) : (f32, tensor<8x8xf32>, tensor<8x8xf32>) -> tensor<8x8xf32> = - (%arg2: f32, %arg3: !flow.dispatch.input<8x8xf32>, %arg4: !flow.dispatch.input<8x8xf32>, %arg5: !flow.dispatch.output<8x8xf32>) { - %c0 = constant 0 : index - %c1 = constant 1 : index - %c8 = constant 8 : index - %1 = linalg.init_tensor [8, 8] : tensor<8x8xf32> - %2 = linalg.fill(%1, %arg2) : tensor<8x8xf32>, f32 -> tensor<8x8xf32> - %3 = flow.dispatch.input.load %arg3, offsets = [%c0, %c0], sizes = [%c8, %c8], strides = [%c1, %c1] : !flow.dispatch.input<8x8xf32> -> tensor<8x8xf32> - %4 = flow.dispatch.input.load %arg4, offsets = [%c0, %c0], sizes = [%c8, %c8], strides = [%c1, %c1] : !flow.dispatch.input<8x8xf32> -> tensor<8x8xf32> - %5 = linalg.matmul ins(%3, %4 : tensor<8x8xf32>, tensor<8x8xf32>) outs(%2 : tensor<8x8xf32>) -> tensor<8x8xf32> - flow.dispatch.output.store %5, %arg5, offsets = [%c0, %c0], sizes = [%c8, %c8], strides = [%c1, %c1] : tensor<8x8xf32> -> !flow.dispatch.output<8x8xf32> - flow.return - } - return %0: tensor<8x8xf32> -} - -// CHECK: func @rematerialize_dispatch_workgroups(%[[ARG1:.+]]: tensor<8x8xf32>, %[[ARG2:.+]]: tensor<8x8xf32>) -// CHECK: %[[CONST1:.+]] = constant 1 : index -// CHECK: flow.dispatch.workgroups[%[[CONST1]], %[[CONST1]], %[[CONST1]]](%[[ARG1]], %[[ARG2]]) -// CHECK: %[[CONST0:.+]] = constant 0.000000e+00 : f32 -// CHECK: %[[INIT_TENSOR:.+]] = linalg.init_tensor [8, 8] : tensor<8x8xf32> -// CHECK: linalg.fill(%[[INIT_TENSOR]], %[[CONST0]]) : tensor<8x8xf32>, f32 -> tensor<8x8xf32> From 93a1ff13490596c5add7f09bbeeb817c621d4bad Mon Sep 17 00:00:00 2001 From: Geoffrey Martin-Noble Date: Thu, 11 Mar 2021 10:39:25 -0800 Subject: [PATCH 08/17] Replace all globs in IREE core with `enforce_glob` (#5063) This avoids any globs of source files in CMake (which are [discouraged](https://cmake.org/cmake/help/latest/command/file.html#glob)) and any globs being evaluated in bazel_to_cmake (which would make it depend on files other than the BUILD file). It still allows the safety check that you actually included all the files you meant to (which is particularly useful with tests, where a failure to do so is test skipped forever instead of an immediate build failure). The cost is having to explicitly list a new source file when you add it, which seems not so bad. I didn't remove bazel_to_cmake support for glob yet, so I can clean up any stragglers after this lands. Fixes https://github.com/google/iree/issues/1083 --- build_tools/bazel/deep_copy.bzl | 1 + build_tools/bazel/enforce_glob.bzl | 50 ++++++++++++++++ .../bazel_to_cmake_converter.py | 3 + .../tensorflow/iree_tf_compiler/TF/test/BUILD | 14 ++++- .../iree_tf_compiler/TFL/test/BUILD | 10 +++- .../iree_tf_compiler/TFL/test/import/BUILD | 9 ++- .../dialect/tf_strings/conversion/test/BUILD | 9 ++- .../dialect/tf_strings/ir/BUILD | 9 ++- .../tf_tensorlist/conversion/test/BUILD | 9 ++- .../dialect/tf_tensorlist/ir/BUILD | 9 ++- .../dialect/tf_tensorlist/ir/test/BUILD | 6 +- .../Bindings/TFLite/Transforms/test/BUILD | 9 ++- .../TFLite/Transforms/test/CMakeLists.txt | 4 +- iree/compiler/Conversion/Common/test/BUILD | 10 +++- .../Conversion/Common/test/CMakeLists.txt | 5 +- iree/compiler/Conversion/HLOToHLO/test/BUILD | 9 ++- .../Conversion/HLOToHLO/test/CMakeLists.txt | 4 +- .../Conversion/HLOToLinalg/test/BUILD | 24 +++++++- .../HLOToLinalg/test/CMakeLists.txt | 19 +++++- .../Conversion/LinalgToLLVM/test/BUILD | 16 ++++- .../LinalgToLLVM/test/CMakeLists.txt | 11 +++- .../Conversion/LinalgToNVVM/test/BUILD | 9 ++- .../LinalgToNVVM/test/CMakeLists.txt | 4 +- .../Conversion/LinalgToSPIRV/test/BUILD | 25 +++++++- .../LinalgToSPIRV/test/CMakeLists.txt | 20 ++++++- .../Conversion/LinalgToVector/test/BUILD | 9 ++- .../LinalgToVector/test/CMakeLists.txt | 4 +- .../compiler/Dialect/Flow/Analysis/test/BUILD | 6 +- .../Dialect/Flow/Analysis/test/CMakeLists.txt | 3 +- .../Flow/Conversion/HLOToFlow/test/BUILD | 30 ---------- .../Conversion/HLOToFlow/test/CMakeLists.txt | 24 -------- .../Flow/Conversion/StandardToFlow/test/BUILD | 30 ---------- .../StandardToFlow/test/CMakeLists.txt | 24 -------- iree/compiler/Dialect/Flow/IR/BUILD | 9 ++- iree/compiler/Dialect/Flow/IR/CMakeLists.txt | 1 - iree/compiler/Dialect/Flow/IR/test/BUILD | 20 ++++++- .../Dialect/Flow/IR/test/CMakeLists.txt | 15 ++++- .../Dialect/Flow/Transforms/test/BUILD | 35 ++++++++++- .../Flow/Transforms/test/CMakeLists.txt | 30 +++++++++- .../HAL/Conversion/FlowToHAL/test/BUILD | 11 +++- .../Conversion/FlowToHAL/test/CMakeLists.txt | 6 +- .../HAL/Conversion/HALToHAL/test/BUILD | 6 +- .../Conversion/HALToHAL/test/CMakeLists.txt | 3 +- .../Dialect/HAL/Conversion/HALToVM/test/BUILD | 16 ++++- .../Conversion/HALToVM/test/CMakeLists.txt | 11 +++- .../HAL/Conversion/IREEToHAL/test/BUILD | 6 +- .../Conversion/IREEToHAL/test/CMakeLists.txt | 3 +- .../HAL/Conversion/StandardToHAL/test/BUILD | 6 +- .../StandardToHAL/test/CMakeLists.txt | 3 +- iree/compiler/Dialect/HAL/IR/BUILD | 9 ++- iree/compiler/Dialect/HAL/IR/CMakeLists.txt | 1 - iree/compiler/Dialect/HAL/IR/test/BUILD | 25 +++++++- .../Dialect/HAL/IR/test/CMakeLists.txt | 20 ++++++- .../Dialect/HAL/Target/CUDA/test/BUILD | 6 +- .../HAL/Target/CUDA/test/CMakeLists.txt | 3 +- .../Dialect/HAL/Target/LLVM/test/BUILD | 9 ++- .../HAL/Target/LLVM/test/CMakeLists.txt | 4 +- .../Dialect/HAL/Target/VMLA/test/BUILD | 10 +++- .../HAL/Target/VMLA/test/CMakeLists.txt | 5 +- .../Dialect/HAL/Target/VulkanSPIRV/test/BUILD | 6 +- .../Target/VulkanSPIRV/test/CMakeLists.txt | 3 +- .../Dialect/HAL/Transforms/test/BUILD | 19 +++++- .../HAL/Transforms/test/CMakeLists.txt | 14 ++++- .../Dialect/IREE/Conversion/test/BUILD | 6 +- .../IREE/Conversion/test/CMakeLists.txt | 3 +- iree/compiler/Dialect/IREE/IR/BUILD | 9 ++- iree/compiler/Dialect/IREE/IR/CMakeLists.txt | 1 - iree/compiler/Dialect/IREE/IR/test/BUILD | 10 +++- .../Dialect/IREE/IR/test/CMakeLists.txt | 5 +- .../Dialect/IREE/Transforms/test/BUILD | 6 +- .../IREE/Transforms/test/CMakeLists.txt | 3 +- iree/compiler/Dialect/Modules/Check/IR/BUILD | 6 +- .../Dialect/Modules/Check/IR/CMakeLists.txt | 1 - .../compiler/Dialect/Modules/Check/test/BUILD | 9 ++- .../Dialect/Modules/Check/test/CMakeLists.txt | 4 +- .../Modules/Strings/Conversion/test/BUILD | 6 +- .../Strings/Conversion/test/CMakeLists.txt | 3 +- .../compiler/Dialect/Modules/Strings/IR/BUILD | 6 +- .../Dialect/Modules/Strings/IR/CMakeLists.txt | 1 - .../Dialect/Modules/Strings/IR/test/BUILD | 6 +- .../Modules/Strings/IR/test/CMakeLists.txt | 3 +- .../compiler/Dialect/Modules/TensorList/BUILD | 5 -- .../Dialect/Modules/TensorList/CMakeLists.txt | 1 - .../Modules/TensorList/Conversion/test/BUILD | 9 ++- .../TensorList/Conversion/test/CMakeLists.txt | 4 +- .../Dialect/Modules/TensorList/IR/BUILD | 9 ++- .../Modules/TensorList/IR/CMakeLists.txt | 1 - .../Dialect/Modules/TensorList/IR/test/BUILD | 6 +- .../Modules/TensorList/IR/test/CMakeLists.txt | 3 +- iree/compiler/Dialect/Sequence/IR/BUILD | 9 ++- .../Dialect/Sequence/IR/CMakeLists.txt | 1 - iree/compiler/Dialect/Sequence/IR/test/BUILD | 9 ++- .../Dialect/Sequence/IR/test/CMakeLists.txt | 4 +- .../Dialect/Shape/Conversion/test/BUILD | 6 +- .../Shape/Conversion/test/CMakeLists.txt | 3 +- iree/compiler/Dialect/Shape/IR/BUILD | 9 ++- iree/compiler/Dialect/Shape/IR/CMakeLists.txt | 1 - iree/compiler/Dialect/Shape/IR/test/BUILD | 11 +++- .../Dialect/Shape/IR/test/CMakeLists.txt | 6 +- .../Dialect/Shape/Plugins/VMLA/test/BUILD | 6 +- .../Shape/Plugins/VMLA/test/CMakeLists.txt | 3 +- .../Dialect/Shape/Plugins/XLA/test/BUILD | 6 +- .../Shape/Plugins/XLA/test/CMakeLists.txt | 3 +- .../Dialect/Shape/Transforms/test/BUILD | 14 ++++- .../Shape/Transforms/test/CMakeLists.txt | 9 ++- iree/compiler/Dialect/VM/Analysis/test/BUILD | 9 ++- .../Dialect/VM/Analysis/test/CMakeLists.txt | 4 +- .../Dialect/VM/Conversion/IREEToVM/test/BUILD | 9 ++- .../Conversion/IREEToVM/test/CMakeLists.txt | 4 +- .../VM/Conversion/StandardToVM/test/BUILD | 14 ++++- .../StandardToVM/test/CMakeLists.txt | 9 ++- iree/compiler/Dialect/VM/IR/BUILD | 9 ++- iree/compiler/Dialect/VM/IR/CMakeLists.txt | 1 - iree/compiler/Dialect/VM/IR/test/BUILD | 27 ++++++++- .../Dialect/VM/IR/test/CMakeLists.txt | 22 ++++++- .../Dialect/VM/Target/Bytecode/test/BUILD | 10 +++- .../VM/Target/Bytecode/test/CMakeLists.txt | 5 +- .../compiler/Dialect/VM/Transforms/test/BUILD | 12 +++- .../Dialect/VM/Transforms/test/CMakeLists.txt | 7 ++- .../VMLA/Conversion/HALToVMLA/test/BUILD | 6 +- .../Conversion/HALToVMLA/test/CMakeLists.txt | 3 +- .../VMLA/Conversion/HLOToVMLA/test/BUILD | 21 ++++++- .../Conversion/HLOToVMLA/test/CMakeLists.txt | 16 ++++- .../VMLA/Conversion/StandardToVMLA/test/BUILD | 10 +++- .../StandardToVMLA/test/CMakeLists.txt | 5 +- .../VMLA/Conversion/VMLAToVM/test/BUILD | 9 ++- .../Conversion/VMLAToVM/test/CMakeLists.txt | 4 +- iree/compiler/Dialect/VMLA/IR/BUILD | 9 ++- iree/compiler/Dialect/VMLA/IR/CMakeLists.txt | 1 - iree/compiler/Dialect/VMLA/IR/test/BUILD | 11 +++- .../Dialect/VMLA/IR/test/CMakeLists.txt | 6 +- .../Dialect/VMLA/Transforms/test/BUILD | 10 +++- .../VMLA/Transforms/test/CMakeLists.txt | 5 +- iree/compiler/Dialect/Vulkan/IR/BUILD | 9 ++- .../compiler/Dialect/Vulkan/IR/CMakeLists.txt | 1 - iree/compiler/Dialect/Vulkan/IR/test/BUILD | 6 +- .../Dialect/Vulkan/IR/test/CMakeLists.txt | 3 +- iree/compiler/Dialect/Vulkan/Utils/test/BUILD | 6 +- .../Dialect/Vulkan/Utils/test/CMakeLists.txt | 3 +- iree/compiler/Translation/test/BUILD | 9 ++- iree/compiler/Translation/test/CMakeLists.txt | 4 +- iree/modules/check/test/BUILD | 9 ++- iree/modules/check/test/CMakeLists.txt | 4 +- iree/samples/custom_modules/dialect/BUILD | 6 +- .../custom_modules/dialect/CMakeLists.txt | 1 - .../samples/custom_modules/dialect/test/BUILD | 18 +++--- .../dialect/test/CMakeLists.txt | 4 +- iree/test/e2e/hackability/BUILD | 6 +- iree/test/e2e/hackability/CMakeLists.txt | 3 +- iree/test/e2e/linalg_tensor_ops/BUILD | 13 +++- .../test/e2e/linalg_tensor_ops/CMakeLists.txt | 8 +-- iree/test/e2e/models/BUILD | 15 ++++- iree/test/e2e/models/CMakeLists.txt | 11 ++-- iree/test/e2e/regression/BUILD | 22 ++++++- iree/test/e2e/regression/CMakeLists.txt | 24 ++++---- iree/test/e2e/structural/BUILD | 11 +++- iree/test/e2e/structural/CMakeLists.txt | 6 +- iree/test/e2e/tosa_ops/BUILD | 39 +++++++++++- iree/test/e2e/tosa_ops/CMakeLists.txt | 60 +++++++++++++++++-- iree/test/e2e/vulkan_specific/BUILD | 14 ++++- iree/test/e2e/vulkan_specific/CMakeLists.txt | 11 ++-- iree/test/e2e/xla_ops/BUILD | 53 +++++++++++++++- iree/test/e2e/xla_ops/CMakeLists.txt | 48 ++++++++++++++- iree/tools/test/BUILD | 14 ++++- iree/tools/test/CMakeLists.txt | 9 ++- 165 files changed, 1315 insertions(+), 362 deletions(-) create mode 100644 build_tools/bazel/enforce_glob.bzl delete mode 100644 iree/compiler/Dialect/Flow/Conversion/HLOToFlow/test/BUILD delete mode 100644 iree/compiler/Dialect/Flow/Conversion/HLOToFlow/test/CMakeLists.txt delete mode 100644 iree/compiler/Dialect/Flow/Conversion/StandardToFlow/test/BUILD delete mode 100644 iree/compiler/Dialect/Flow/Conversion/StandardToFlow/test/CMakeLists.txt diff --git a/build_tools/bazel/deep_copy.bzl b/build_tools/bazel/deep_copy.bzl index a7c1c29f6211..32ce6e6ea464 100644 --- a/build_tools/bazel/deep_copy.bzl +++ b/build_tools/bazel/deep_copy.bzl @@ -43,6 +43,7 @@ def _deep_copy_recursion_depth_1(x): def deep_copy(x): """Returns a copy of the argument, making a deep copy if it is a container. + Args: x: (object) value to copy. If it is a container with nested containers as elements, the maximum nesting depth is restricted to three (e.g., diff --git a/build_tools/bazel/enforce_glob.bzl b/build_tools/bazel/enforce_glob.bzl new file mode 100644 index 000000000000..58cb4541087f --- /dev/null +++ b/build_tools/bazel/enforce_glob.bzl @@ -0,0 +1,50 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""A utility to enforce that a list matches a glob expression. + +We use this primarily to enable the error-checking capabilities of globs of test +files in IREE while still allowing our Bazel to CMake conversion to not create +CMake globs (which are discouraged for collecting source files, see +https://cmake.org/cmake/help/latest/command/file.html#glob) and not be dependent +on any information outside of the BUILD file. +""" + +def enforce_glob(files, **kwargs): + """A utility to enforce that a list matches a glob expression. + + Note that the comparison is done in an order-independent fashion. + + Args: + files: a list that is expected to contain the same files as the + specified glob expression. + **kwargs: keyword arguments forwarded to the glob. + + Returns: + files. The input argument unchanged + """ + glob_result = native.glob(**kwargs) + + # glob returns a sorted list. + if sorted(files) != glob_result: + glob_result_dict = {k: None for k in glob_result} + result_dict = {k: None for k in files} + missing = [k for k in glob_result if k not in files] + extra = [k for k in files if k not in glob_result] + fail(("Error in enforce_glob." + + "\nExpected {}." + + "\nGot {}." + + "\nMissing {}." + + "\nExtra {}").format(glob_result, files, missing, extra)) + return files diff --git a/build_tools/bazel_to_cmake/bazel_to_cmake_converter.py b/build_tools/bazel_to_cmake/bazel_to_cmake_converter.py index b6d32c23eb78..6a5e852697dd 100644 --- a/build_tools/bazel_to_cmake/bazel_to_cmake_converter.py +++ b/build_tools/bazel_to_cmake/bazel_to_cmake_converter.py @@ -251,6 +251,9 @@ def filegroup(self, name, **kwargs): def sh_binary(self, name, **kwargs): self._convert_unimplemented_function("sh_binary", name) + def enforce_glob(self, files, **kwargs): + return files + def glob(self, include, exclude=None, exclude_directories=1): if exclude_directories != 1: self._convert_unimplemented_function("glob", "with exclude_directories") diff --git a/integrations/tensorflow/iree_tf_compiler/TF/test/BUILD b/integrations/tensorflow/iree_tf_compiler/TF/test/BUILD index 1ed07b949005..60d2030d023b 100644 --- a/integrations/tensorflow/iree_tf_compiler/TF/test/BUILD +++ b/integrations/tensorflow/iree_tf_compiler/TF/test/BUILD @@ -13,6 +13,7 @@ # limitations under the License. load("@iree//iree:lit_test.bzl", "iree_lit_test_suite") +load("@iree//build_tools/bazel:enforce_glob.bzl", "enforce_glob") package( default_visibility = ["//visibility:public"], @@ -22,7 +23,18 @@ package( iree_lit_test_suite( name = "lit", - srcs = glob(["*.mlir"]), + srcs = enforce_glob( + [ + "convert_to_mhlo.mlir", + "lower_global_tensors.mlir", + "lower_global_tensors_complex.mlir", + "lower_global_tensors_invalid.mlir", + "propagate_resource_casts.mlir", + "strip_metadata.mlir", + "verify_fully_converted.mlir", + ], + include = ["*.mlir"], + ), data = [ "//iree_tf_compiler:iree-tf-opt", "@iree//iree/tools:IreeFileCheck", diff --git a/integrations/tensorflow/iree_tf_compiler/TFL/test/BUILD b/integrations/tensorflow/iree_tf_compiler/TFL/test/BUILD index e9acbed21709..18de285136e7 100644 --- a/integrations/tensorflow/iree_tf_compiler/TFL/test/BUILD +++ b/integrations/tensorflow/iree_tf_compiler/TFL/test/BUILD @@ -13,6 +13,7 @@ # limitations under the License. load("@iree//iree:lit_test.bzl", "iree_lit_test_suite") +load("@iree//build_tools/bazel:enforce_glob.bzl", "enforce_glob") package( default_visibility = ["//visibility:public"], @@ -22,7 +23,14 @@ package( iree_lit_test_suite( name = "lit", - srcs = glob(["*.mlir"]), + srcs = enforce_glob( + [ + "convert_metadata.mlir", + "strip_metadata.mlir", + "verify_fully_converted.mlir", + ], + include = ["*.mlir"], + ), data = [ "//iree_tf_compiler:iree-opt-tflite", "@iree//iree/tools:IreeFileCheck", diff --git a/integrations/tensorflow/iree_tf_compiler/TFL/test/import/BUILD b/integrations/tensorflow/iree_tf_compiler/TFL/test/import/BUILD index 78398351c8f6..7f2f19fb1936 100644 --- a/integrations/tensorflow/iree_tf_compiler/TFL/test/import/BUILD +++ b/integrations/tensorflow/iree_tf_compiler/TFL/test/import/BUILD @@ -13,6 +13,7 @@ # limitations under the License. load("@iree//iree:lit_test.bzl", "iree_lit_test_suite") +load("@iree//build_tools/bazel:enforce_glob.bzl", "enforce_glob") package( default_visibility = ["//visibility:public"], @@ -22,7 +23,13 @@ package( iree_lit_test_suite( name = "lit", - srcs = glob(["*.mlir"]), + srcs = enforce_glob( + [ + "add.mlir", + "multi_add.mlir", + ], + include = ["*.mlir"], + ), data = glob(["*.tflite"]) + [ "//iree_tf_compiler:iree-import-tflite", "@iree//iree/tools:IreeFileCheck", diff --git a/integrations/tensorflow/iree_tf_compiler/dialect/tf_strings/conversion/test/BUILD b/integrations/tensorflow/iree_tf_compiler/dialect/tf_strings/conversion/test/BUILD index a2b8e05dad1d..da1fa3f9119e 100644 --- a/integrations/tensorflow/iree_tf_compiler/dialect/tf_strings/conversion/test/BUILD +++ b/integrations/tensorflow/iree_tf_compiler/dialect/tf_strings/conversion/test/BUILD @@ -15,6 +15,7 @@ # Tests for lowering MLIR in various dialects to IREE interpreter bytecode. load("@iree//iree:lit_test.bzl", "iree_lit_test_suite") +load("@iree//build_tools/bazel:enforce_glob.bzl", "enforce_glob") package( default_visibility = ["//visibility:public"], @@ -24,7 +25,13 @@ package( iree_lit_test_suite( name = "lit", - srcs = glob(["*.mlir"]), + srcs = enforce_glob( + [ + "tf_strings_to_strings.mlir", + "tf_to_tf_strings.mlir", + ], + include = ["*.mlir"], + ), data = [ "//iree_tf_compiler:iree-tf-opt", "@iree//iree/tools:IreeFileCheck", diff --git a/integrations/tensorflow/iree_tf_compiler/dialect/tf_strings/ir/BUILD b/integrations/tensorflow/iree_tf_compiler/dialect/tf_strings/ir/BUILD index 24e3f79e5a8f..abf320c12d91 100644 --- a/integrations/tensorflow/iree_tf_compiler/dialect/tf_strings/ir/BUILD +++ b/integrations/tensorflow/iree_tf_compiler/dialect/tf_strings/ir/BUILD @@ -13,6 +13,7 @@ # limitations under the License. load("@iree//build_tools/bazel:tblgen.bzl", "gentbl") +load("@iree//build_tools/bazel:enforce_glob.bzl", "enforce_glob") package( default_visibility = ["//visibility:public"], @@ -22,7 +23,13 @@ package( filegroup( name = "td_files", - srcs = glob(["*.td"]), + srcs = enforce_glob( + [ + "base.td", + "ops.td", + ], + include = ["*.td"], + ), ) gentbl( diff --git a/integrations/tensorflow/iree_tf_compiler/dialect/tf_tensorlist/conversion/test/BUILD b/integrations/tensorflow/iree_tf_compiler/dialect/tf_tensorlist/conversion/test/BUILD index 6b37940aeb03..305e762ef240 100644 --- a/integrations/tensorflow/iree_tf_compiler/dialect/tf_tensorlist/conversion/test/BUILD +++ b/integrations/tensorflow/iree_tf_compiler/dialect/tf_tensorlist/conversion/test/BUILD @@ -13,6 +13,7 @@ # limitations under the License. load("@iree//iree:lit_test.bzl", "iree_lit_test_suite") +load("@iree//build_tools/bazel:enforce_glob.bzl", "enforce_glob") package( default_visibility = ["//visibility:public"], @@ -22,7 +23,13 @@ package( iree_lit_test_suite( name = "lit", - srcs = glob(["*.mlir"]), + srcs = enforce_glob( + [ + "convert_tf_tensorlist_to_tensorlist.mlir", + "convert_tf_to_tf_tensorlist.mlir", + ], + include = ["*.mlir"], + ), data = [ "//iree_tf_compiler:iree-tf-opt", "@iree//iree/tools:IreeFileCheck", diff --git a/integrations/tensorflow/iree_tf_compiler/dialect/tf_tensorlist/ir/BUILD b/integrations/tensorflow/iree_tf_compiler/dialect/tf_tensorlist/ir/BUILD index c0745fc4c2b8..80b2c552de44 100644 --- a/integrations/tensorflow/iree_tf_compiler/dialect/tf_tensorlist/ir/BUILD +++ b/integrations/tensorflow/iree_tf_compiler/dialect/tf_tensorlist/ir/BUILD @@ -13,6 +13,7 @@ # limitations under the License. load("@iree//build_tools/bazel:tblgen.bzl", "gentbl") +load("@iree//build_tools/bazel:enforce_glob.bzl", "enforce_glob") package( default_visibility = ["//visibility:public"], @@ -24,7 +25,13 @@ exports_files(["tf_tensorlist_base.td"]) filegroup( name = "td_files", - srcs = glob(["*.td"]), + srcs = enforce_glob( + [ + "tf_tensorlist_base.td", + "tf_tensorlist_ops.td", + ], + include = ["*.td"], + ), ) cc_library( diff --git a/integrations/tensorflow/iree_tf_compiler/dialect/tf_tensorlist/ir/test/BUILD b/integrations/tensorflow/iree_tf_compiler/dialect/tf_tensorlist/ir/test/BUILD index 6b37940aeb03..b700c9f6e652 100644 --- a/integrations/tensorflow/iree_tf_compiler/dialect/tf_tensorlist/ir/test/BUILD +++ b/integrations/tensorflow/iree_tf_compiler/dialect/tf_tensorlist/ir/test/BUILD @@ -13,6 +13,7 @@ # limitations under the License. load("@iree//iree:lit_test.bzl", "iree_lit_test_suite") +load("@iree//build_tools/bazel:enforce_glob.bzl", "enforce_glob") package( default_visibility = ["//visibility:public"], @@ -22,7 +23,10 @@ package( iree_lit_test_suite( name = "lit", - srcs = glob(["*.mlir"]), + srcs = enforce_glob( + ["ops.mlir"], + include = ["*.mlir"], + ), data = [ "//iree_tf_compiler:iree-tf-opt", "@iree//iree/tools:IreeFileCheck", diff --git a/iree/compiler/Bindings/TFLite/Transforms/test/BUILD b/iree/compiler/Bindings/TFLite/Transforms/test/BUILD index b780b112a316..95904a789942 100644 --- a/iree/compiler/Bindings/TFLite/Transforms/test/BUILD +++ b/iree/compiler/Bindings/TFLite/Transforms/test/BUILD @@ -13,6 +13,7 @@ # limitations under the License. load("//iree:lit_test.bzl", "iree_lit_test_suite") +load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob") package( default_visibility = ["//visibility:public"], @@ -22,7 +23,13 @@ package( iree_lit_test_suite( name = "lit", - srcs = glob(["*.mlir"]), + srcs = enforce_glob( + [ + "materialize_shape_support.mlir", + "wrap_entry_points.mlir", + ], + include = ["*.mlir"], + ), data = [ "//iree/tools:IreeFileCheck", "//iree/tools:iree-opt", diff --git a/iree/compiler/Bindings/TFLite/Transforms/test/CMakeLists.txt b/iree/compiler/Bindings/TFLite/Transforms/test/CMakeLists.txt index 96f36fe84237..365af3dbd2db 100644 --- a/iree/compiler/Bindings/TFLite/Transforms/test/CMakeLists.txt +++ b/iree/compiler/Bindings/TFLite/Transforms/test/CMakeLists.txt @@ -10,12 +10,12 @@ iree_add_all_subdirs() -file(GLOB _GLOB_X_MLIR LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.mlir) iree_lit_test_suite( NAME lit SRCS - "${_GLOB_X_MLIR}" + "materialize_shape_support.mlir" + "wrap_entry_points.mlir" DATA iree::tools::IreeFileCheck iree::tools::iree-opt diff --git a/iree/compiler/Conversion/Common/test/BUILD b/iree/compiler/Conversion/Common/test/BUILD index 1e3b7bba341b..0dbfe82569f8 100644 --- a/iree/compiler/Conversion/Common/test/BUILD +++ b/iree/compiler/Conversion/Common/test/BUILD @@ -15,6 +15,7 @@ # Tests for common transforms. load("//iree:lit_test.bzl", "iree_lit_test_suite") +load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob") package( default_visibility = ["//visibility:public"], @@ -24,7 +25,14 @@ package( iree_lit_test_suite( name = "lit", - srcs = glob(["*.mlir"]), + srcs = enforce_glob( + [ + "linalg_bufferize.mlir", + "linalg_rewrite_destructive_updates.mlir", + "remove_dead_allocs.mlir", + ], + include = ["*.mlir"], + ), data = [ "//iree/tools:IreeFileCheck", "//iree/tools:iree-opt", diff --git a/iree/compiler/Conversion/Common/test/CMakeLists.txt b/iree/compiler/Conversion/Common/test/CMakeLists.txt index 899242936781..5684d46df19c 100644 --- a/iree/compiler/Conversion/Common/test/CMakeLists.txt +++ b/iree/compiler/Conversion/Common/test/CMakeLists.txt @@ -10,12 +10,13 @@ iree_add_all_subdirs() -file(GLOB _GLOB_X_MLIR LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.mlir) iree_lit_test_suite( NAME lit SRCS - "${_GLOB_X_MLIR}" + "linalg_bufferize.mlir" + "linalg_rewrite_destructive_updates.mlir" + "remove_dead_allocs.mlir" DATA iree::tools::IreeFileCheck iree::tools::iree-opt diff --git a/iree/compiler/Conversion/HLOToHLO/test/BUILD b/iree/compiler/Conversion/HLOToHLO/test/BUILD index 1e3b7bba341b..ab0ba6be34c1 100644 --- a/iree/compiler/Conversion/HLOToHLO/test/BUILD +++ b/iree/compiler/Conversion/HLOToHLO/test/BUILD @@ -15,6 +15,7 @@ # Tests for common transforms. load("//iree:lit_test.bzl", "iree_lit_test_suite") +load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob") package( default_visibility = ["//visibility:public"], @@ -24,7 +25,13 @@ package( iree_lit_test_suite( name = "lit", - srcs = glob(["*.mlir"]), + srcs = enforce_glob( + [ + "conv1x12dot.mlir", + "f32Tof16.mlir", + ], + include = ["*.mlir"], + ), data = [ "//iree/tools:IreeFileCheck", "//iree/tools:iree-opt", diff --git a/iree/compiler/Conversion/HLOToHLO/test/CMakeLists.txt b/iree/compiler/Conversion/HLOToHLO/test/CMakeLists.txt index 420c06116ccb..50035da4b3e0 100644 --- a/iree/compiler/Conversion/HLOToHLO/test/CMakeLists.txt +++ b/iree/compiler/Conversion/HLOToHLO/test/CMakeLists.txt @@ -10,12 +10,12 @@ iree_add_all_subdirs() -file(GLOB _GLOB_X_MLIR LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.mlir) iree_lit_test_suite( NAME lit SRCS - "${_GLOB_X_MLIR}" + "conv1x12dot.mlir" + "f32Tof16.mlir" DATA iree::tools::IreeFileCheck iree::tools::iree-opt diff --git a/iree/compiler/Conversion/HLOToLinalg/test/BUILD b/iree/compiler/Conversion/HLOToLinalg/test/BUILD index 1e3b7bba341b..4b1bc253c185 100644 --- a/iree/compiler/Conversion/HLOToLinalg/test/BUILD +++ b/iree/compiler/Conversion/HLOToLinalg/test/BUILD @@ -15,6 +15,7 @@ # Tests for common transforms. load("//iree:lit_test.bzl", "iree_lit_test_suite") +load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob") package( default_visibility = ["//visibility:public"], @@ -24,7 +25,28 @@ package( iree_lit_test_suite( name = "lit", - srcs = glob(["*.mlir"]), + srcs = enforce_glob( + [ + "arithmetic_ops.mlir", + "concatenate.mlir", + "decompose_hlo_clamp.mlir", + "depthwise_conv.mlir", + "dot.mlir", + "dynamic_shape.mlir", + "exp.mlir", + "fusion.mlir", + "linalg_tensor_to_buffer.mlir", + "pad_tensor.mlir", + "pad_tensor_to_tensor.mlir", + "pipeline_test.mlir", + "reduce.mlir", + "reduce_window.mlir", + "subtensor.mlir", + "subtensor_insert.mlir", + "torch_index_select.mlir", + ], + include = ["*.mlir"], + ), data = [ "//iree/tools:IreeFileCheck", "//iree/tools:iree-opt", diff --git a/iree/compiler/Conversion/HLOToLinalg/test/CMakeLists.txt b/iree/compiler/Conversion/HLOToLinalg/test/CMakeLists.txt index 927972fd2126..2f24788e795a 100644 --- a/iree/compiler/Conversion/HLOToLinalg/test/CMakeLists.txt +++ b/iree/compiler/Conversion/HLOToLinalg/test/CMakeLists.txt @@ -10,12 +10,27 @@ iree_add_all_subdirs() -file(GLOB _GLOB_X_MLIR LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.mlir) iree_lit_test_suite( NAME lit SRCS - "${_GLOB_X_MLIR}" + "arithmetic_ops.mlir" + "concatenate.mlir" + "decompose_hlo_clamp.mlir" + "depthwise_conv.mlir" + "dot.mlir" + "dynamic_shape.mlir" + "exp.mlir" + "fusion.mlir" + "linalg_tensor_to_buffer.mlir" + "pad_tensor.mlir" + "pad_tensor_to_tensor.mlir" + "pipeline_test.mlir" + "reduce.mlir" + "reduce_window.mlir" + "subtensor.mlir" + "subtensor_insert.mlir" + "torch_index_select.mlir" DATA iree::tools::IreeFileCheck iree::tools::iree-opt diff --git a/iree/compiler/Conversion/LinalgToLLVM/test/BUILD b/iree/compiler/Conversion/LinalgToLLVM/test/BUILD index 1e3b7bba341b..7df28f0a92f5 100644 --- a/iree/compiler/Conversion/LinalgToLLVM/test/BUILD +++ b/iree/compiler/Conversion/LinalgToLLVM/test/BUILD @@ -15,6 +15,7 @@ # Tests for common transforms. load("//iree:lit_test.bzl", "iree_lit_test_suite") +load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob") package( default_visibility = ["//visibility:public"], @@ -24,7 +25,20 @@ package( iree_lit_test_suite( name = "lit", - srcs = glob(["*.mlir"]), + srcs = enforce_glob( + [ + "conv_img2col.mlir", + "hal_interface_bindings.mlir", + "hal_interface_constants.mlir", + "hal_interface_workgroup_info.mlir", + "linalg_vectorize.mlir", + "materialize_launch_configuration.mlir", + "matmul_vectorization.mlir", + "plan_conv_loop_order.mlir", + "tile_and_distribute.mlir", + ], + include = ["*.mlir"], + ), data = [ "//iree/tools:IreeFileCheck", "//iree/tools:iree-opt", diff --git a/iree/compiler/Conversion/LinalgToLLVM/test/CMakeLists.txt b/iree/compiler/Conversion/LinalgToLLVM/test/CMakeLists.txt index 725edceb656d..661c4ab205c5 100644 --- a/iree/compiler/Conversion/LinalgToLLVM/test/CMakeLists.txt +++ b/iree/compiler/Conversion/LinalgToLLVM/test/CMakeLists.txt @@ -10,12 +10,19 @@ iree_add_all_subdirs() -file(GLOB _GLOB_X_MLIR LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.mlir) iree_lit_test_suite( NAME lit SRCS - "${_GLOB_X_MLIR}" + "conv_img2col.mlir" + "hal_interface_bindings.mlir" + "hal_interface_constants.mlir" + "hal_interface_workgroup_info.mlir" + "linalg_vectorize.mlir" + "materialize_launch_configuration.mlir" + "matmul_vectorization.mlir" + "plan_conv_loop_order.mlir" + "tile_and_distribute.mlir" DATA iree::tools::IreeFileCheck iree::tools::iree-opt diff --git a/iree/compiler/Conversion/LinalgToNVVM/test/BUILD b/iree/compiler/Conversion/LinalgToNVVM/test/BUILD index 675ed9130dc0..a8cd50f554f3 100644 --- a/iree/compiler/Conversion/LinalgToNVVM/test/BUILD +++ b/iree/compiler/Conversion/LinalgToNVVM/test/BUILD @@ -15,6 +15,7 @@ # Tests for common transforms. load("//iree:lit_test.bzl", "iree_lit_test_suite") +load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob") package( default_visibility = ["//visibility:public"], @@ -24,7 +25,13 @@ package( iree_lit_test_suite( name = "lit", - srcs = glob(["*.mlir"]), + srcs = enforce_glob( + [ + "convert_to_nvvm.mlir", + "pipeline_test.mlir", + ], + include = ["*.mlir"], + ), data = [ "//iree/tools:IreeFileCheck", "//iree/tools:iree-opt", diff --git a/iree/compiler/Conversion/LinalgToNVVM/test/CMakeLists.txt b/iree/compiler/Conversion/LinalgToNVVM/test/CMakeLists.txt index 3779e3cabe7d..2d850ad956ec 100644 --- a/iree/compiler/Conversion/LinalgToNVVM/test/CMakeLists.txt +++ b/iree/compiler/Conversion/LinalgToNVVM/test/CMakeLists.txt @@ -10,12 +10,12 @@ iree_add_all_subdirs() -file(GLOB _GLOB_X_MLIR LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.mlir) iree_lit_test_suite( NAME lit SRCS - "${_GLOB_X_MLIR}" + "convert_to_nvvm.mlir" + "pipeline_test.mlir" DATA iree::tools::IreeFileCheck iree::tools::iree-opt diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/BUILD b/iree/compiler/Conversion/LinalgToSPIRV/test/BUILD index 1e3b7bba341b..8ed1c59a3487 100644 --- a/iree/compiler/Conversion/LinalgToSPIRV/test/BUILD +++ b/iree/compiler/Conversion/LinalgToSPIRV/test/BUILD @@ -15,6 +15,7 @@ # Tests for common transforms. load("//iree:lit_test.bzl", "iree_lit_test_suite") +load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob") package( default_visibility = ["//visibility:public"], @@ -24,7 +25,29 @@ package( iree_lit_test_suite( name = "lit", - srcs = glob(["*.mlir"]), + srcs = enforce_glob( + [ + "batch_matmul_vectorization.mlir", + "convert_to_gpu.mlir", + "convert_to_spirv.mlir", + "dead_alloc.mlir", + "fold-gpu-procid-uses.mlir", + "forop_canonicalization.mlir", + "linalg_tile_and_fuse.mlir", + "materialize_launch_configuration.mlir", + "materialize_launch_configuration2.mlir", + "matmul_fused_vectorization.mlir", + "matmul_vectorization.mlir", + "matmul_vectorization_licm.mlir", + "memref_vecrotization.mlir", + "pipeline_test.mlir", + "pipeline_test_cooperative_mat.mlir", + "split_dispatch_function.mlir", + "vector_to_gpu.mlir", + "workgroup_memory_promotion.mlir", + ], + include = ["*.mlir"], + ), data = [ "//iree/tools:IreeFileCheck", "//iree/tools:iree-opt", diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/CMakeLists.txt b/iree/compiler/Conversion/LinalgToSPIRV/test/CMakeLists.txt index 8dbfbf23eff0..30df4bd36c9e 100644 --- a/iree/compiler/Conversion/LinalgToSPIRV/test/CMakeLists.txt +++ b/iree/compiler/Conversion/LinalgToSPIRV/test/CMakeLists.txt @@ -10,12 +10,28 @@ iree_add_all_subdirs() -file(GLOB _GLOB_X_MLIR LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.mlir) iree_lit_test_suite( NAME lit SRCS - "${_GLOB_X_MLIR}" + "batch_matmul_vectorization.mlir" + "convert_to_gpu.mlir" + "convert_to_spirv.mlir" + "dead_alloc.mlir" + "fold-gpu-procid-uses.mlir" + "forop_canonicalization.mlir" + "linalg_tile_and_fuse.mlir" + "materialize_launch_configuration.mlir" + "materialize_launch_configuration2.mlir" + "matmul_fused_vectorization.mlir" + "matmul_vectorization.mlir" + "matmul_vectorization_licm.mlir" + "memref_vecrotization.mlir" + "pipeline_test.mlir" + "pipeline_test_cooperative_mat.mlir" + "split_dispatch_function.mlir" + "vector_to_gpu.mlir" + "workgroup_memory_promotion.mlir" DATA iree::tools::IreeFileCheck iree::tools::iree-opt diff --git a/iree/compiler/Conversion/LinalgToVector/test/BUILD b/iree/compiler/Conversion/LinalgToVector/test/BUILD index e124cf534051..2997186606bf 100644 --- a/iree/compiler/Conversion/LinalgToVector/test/BUILD +++ b/iree/compiler/Conversion/LinalgToVector/test/BUILD @@ -15,6 +15,7 @@ # Tests for common transforms. load("//iree:lit_test.bzl", "iree_lit_test_suite") +load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob") package( default_visibility = ["//visibility:public"], @@ -24,7 +25,13 @@ package( iree_lit_test_suite( name = "lit", - srcs = glob(["*.mlir"]), + srcs = enforce_glob( + [ + "vectorize_linalg_conv.mlir", + "vectorize_linalg_ops.mlir", + ], + include = ["*.mlir"], + ), data = [ "//iree/tools:IreeFileCheck", "//iree/tools:iree-opt", diff --git a/iree/compiler/Conversion/LinalgToVector/test/CMakeLists.txt b/iree/compiler/Conversion/LinalgToVector/test/CMakeLists.txt index 05e4f22cc2ec..b958f27656d3 100644 --- a/iree/compiler/Conversion/LinalgToVector/test/CMakeLists.txt +++ b/iree/compiler/Conversion/LinalgToVector/test/CMakeLists.txt @@ -10,12 +10,12 @@ iree_add_all_subdirs() -file(GLOB _GLOB_X_MLIR LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.mlir) iree_lit_test_suite( NAME lit SRCS - "${_GLOB_X_MLIR}" + "vectorize_linalg_conv.mlir" + "vectorize_linalg_ops.mlir" DATA iree::tools::IreeFileCheck iree::tools::iree-opt diff --git a/iree/compiler/Dialect/Flow/Analysis/test/BUILD b/iree/compiler/Dialect/Flow/Analysis/test/BUILD index b780b112a316..8f34c74559e2 100644 --- a/iree/compiler/Dialect/Flow/Analysis/test/BUILD +++ b/iree/compiler/Dialect/Flow/Analysis/test/BUILD @@ -13,6 +13,7 @@ # limitations under the License. load("//iree:lit_test.bzl", "iree_lit_test_suite") +load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob") package( default_visibility = ["//visibility:public"], @@ -22,7 +23,10 @@ package( iree_lit_test_suite( name = "lit", - srcs = glob(["*.mlir"]), + srcs = enforce_glob( + ["dispatchability.mlir"], + include = ["*.mlir"], + ), data = [ "//iree/tools:IreeFileCheck", "//iree/tools:iree-opt", diff --git a/iree/compiler/Dialect/Flow/Analysis/test/CMakeLists.txt b/iree/compiler/Dialect/Flow/Analysis/test/CMakeLists.txt index c10e0a2033ed..26aafe0ab62f 100644 --- a/iree/compiler/Dialect/Flow/Analysis/test/CMakeLists.txt +++ b/iree/compiler/Dialect/Flow/Analysis/test/CMakeLists.txt @@ -10,12 +10,11 @@ iree_add_all_subdirs() -file(GLOB _GLOB_X_MLIR LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.mlir) iree_lit_test_suite( NAME lit SRCS - "${_GLOB_X_MLIR}" + "dispatchability.mlir" DATA iree::tools::IreeFileCheck iree::tools::iree-opt diff --git a/iree/compiler/Dialect/Flow/Conversion/HLOToFlow/test/BUILD b/iree/compiler/Dialect/Flow/Conversion/HLOToFlow/test/BUILD deleted file mode 100644 index b780b112a316..000000000000 --- a/iree/compiler/Dialect/Flow/Conversion/HLOToFlow/test/BUILD +++ /dev/null @@ -1,30 +0,0 @@ -# Copyright 2019 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -load("//iree:lit_test.bzl", "iree_lit_test_suite") - -package( - default_visibility = ["//visibility:public"], - features = ["layering_check"], - licenses = ["notice"], # Apache 2.0 -) - -iree_lit_test_suite( - name = "lit", - srcs = glob(["*.mlir"]), - data = [ - "//iree/tools:IreeFileCheck", - "//iree/tools:iree-opt", - ], -) diff --git a/iree/compiler/Dialect/Flow/Conversion/HLOToFlow/test/CMakeLists.txt b/iree/compiler/Dialect/Flow/Conversion/HLOToFlow/test/CMakeLists.txt deleted file mode 100644 index 609b1785dc6e..000000000000 --- a/iree/compiler/Dialect/Flow/Conversion/HLOToFlow/test/CMakeLists.txt +++ /dev/null @@ -1,24 +0,0 @@ -################################################################################ -# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from # -# iree/compiler/Dialect/Flow/Conversion/HLOToFlow/test/BUILD # -# # -# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary # -# CMake-only content. # -# # -# To disable autogeneration for this file entirely, delete this header. # -################################################################################ - -iree_add_all_subdirs() - -file(GLOB _GLOB_X_MLIR LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.mlir) -iree_lit_test_suite( - NAME - lit - SRCS - "${_GLOB_X_MLIR}" - DATA - iree::tools::IreeFileCheck - iree::tools::iree-opt -) - -### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ### diff --git a/iree/compiler/Dialect/Flow/Conversion/StandardToFlow/test/BUILD b/iree/compiler/Dialect/Flow/Conversion/StandardToFlow/test/BUILD deleted file mode 100644 index b780b112a316..000000000000 --- a/iree/compiler/Dialect/Flow/Conversion/StandardToFlow/test/BUILD +++ /dev/null @@ -1,30 +0,0 @@ -# Copyright 2019 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -load("//iree:lit_test.bzl", "iree_lit_test_suite") - -package( - default_visibility = ["//visibility:public"], - features = ["layering_check"], - licenses = ["notice"], # Apache 2.0 -) - -iree_lit_test_suite( - name = "lit", - srcs = glob(["*.mlir"]), - data = [ - "//iree/tools:IreeFileCheck", - "//iree/tools:iree-opt", - ], -) diff --git a/iree/compiler/Dialect/Flow/Conversion/StandardToFlow/test/CMakeLists.txt b/iree/compiler/Dialect/Flow/Conversion/StandardToFlow/test/CMakeLists.txt deleted file mode 100644 index 637de90aa834..000000000000 --- a/iree/compiler/Dialect/Flow/Conversion/StandardToFlow/test/CMakeLists.txt +++ /dev/null @@ -1,24 +0,0 @@ -################################################################################ -# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from # -# iree/compiler/Dialect/Flow/Conversion/StandardToFlow/test/BUILD # -# # -# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary # -# CMake-only content. # -# # -# To disable autogeneration for this file entirely, delete this header. # -################################################################################ - -iree_add_all_subdirs() - -file(GLOB _GLOB_X_MLIR LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.mlir) -iree_lit_test_suite( - NAME - lit - SRCS - "${_GLOB_X_MLIR}" - DATA - iree::tools::IreeFileCheck - iree::tools::iree-opt -) - -### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ### diff --git a/iree/compiler/Dialect/Flow/IR/BUILD b/iree/compiler/Dialect/Flow/IR/BUILD index e6c4c6a630c5..940064be7e4d 100644 --- a/iree/compiler/Dialect/Flow/IR/BUILD +++ b/iree/compiler/Dialect/Flow/IR/BUILD @@ -14,6 +14,7 @@ load("//build_tools/bazel:iree_tablegen_doc.bzl", "iree_tablegen_doc") load("//build_tools/bazel:tblgen.bzl", "gentbl") +load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob") package( default_visibility = ["//visibility:public"], @@ -23,7 +24,13 @@ package( filegroup( name = "td_files", - srcs = glob(["*.td"]), + srcs = enforce_glob( + [ + "FlowBase.td", + "FlowOps.td", + ], + include = ["*.td"], + ), ) cc_library( diff --git a/iree/compiler/Dialect/Flow/IR/CMakeLists.txt b/iree/compiler/Dialect/Flow/IR/CMakeLists.txt index 581bae16396e..3f7e8b68608a 100644 --- a/iree/compiler/Dialect/Flow/IR/CMakeLists.txt +++ b/iree/compiler/Dialect/Flow/IR/CMakeLists.txt @@ -10,7 +10,6 @@ iree_add_all_subdirs() -file(GLOB _GLOB_X_TD LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.td) iree_cc_library( NAME IR diff --git a/iree/compiler/Dialect/Flow/IR/test/BUILD b/iree/compiler/Dialect/Flow/IR/test/BUILD index b780b112a316..b2f3b8d3239a 100644 --- a/iree/compiler/Dialect/Flow/IR/test/BUILD +++ b/iree/compiler/Dialect/Flow/IR/test/BUILD @@ -13,6 +13,7 @@ # limitations under the License. load("//iree:lit_test.bzl", "iree_lit_test_suite") +load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob") package( default_visibility = ["//visibility:public"], @@ -22,7 +23,24 @@ package( iree_lit_test_suite( name = "lit", - srcs = glob(["*.mlir"]), + srcs = enforce_glob( + [ + "dispatch_ops.mlir", + "dispatch_region_folding.mlir", + "dispatch_regions.mlir", + "dispatch_workgroups.mlir", + "dispatch_workgroups_folding.mlir", + "executable_ops.mlir", + "stream_folding.mlir", + "stream_ops.mlir", + "tensor_folding.mlir", + "tensor_ops.mlir", + "types.mlir", + "variable_folding.mlir", + "variable_ops.mlir", + ], + include = ["*.mlir"], + ), data = [ "//iree/tools:IreeFileCheck", "//iree/tools:iree-opt", diff --git a/iree/compiler/Dialect/Flow/IR/test/CMakeLists.txt b/iree/compiler/Dialect/Flow/IR/test/CMakeLists.txt index 28f5faeabaa1..a1b98324a6bc 100644 --- a/iree/compiler/Dialect/Flow/IR/test/CMakeLists.txt +++ b/iree/compiler/Dialect/Flow/IR/test/CMakeLists.txt @@ -10,12 +10,23 @@ iree_add_all_subdirs() -file(GLOB _GLOB_X_MLIR LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.mlir) iree_lit_test_suite( NAME lit SRCS - "${_GLOB_X_MLIR}" + "dispatch_ops.mlir" + "dispatch_region_folding.mlir" + "dispatch_regions.mlir" + "dispatch_workgroups.mlir" + "dispatch_workgroups_folding.mlir" + "executable_ops.mlir" + "stream_folding.mlir" + "stream_ops.mlir" + "tensor_folding.mlir" + "tensor_ops.mlir" + "types.mlir" + "variable_folding.mlir" + "variable_ops.mlir" DATA iree::tools::IreeFileCheck iree::tools::iree-opt diff --git a/iree/compiler/Dialect/Flow/Transforms/test/BUILD b/iree/compiler/Dialect/Flow/Transforms/test/BUILD index b780b112a316..048f9f65da28 100644 --- a/iree/compiler/Dialect/Flow/Transforms/test/BUILD +++ b/iree/compiler/Dialect/Flow/Transforms/test/BUILD @@ -13,6 +13,7 @@ # limitations under the License. load("//iree:lit_test.bzl", "iree_lit_test_suite") +load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob") package( default_visibility = ["//visibility:public"], @@ -22,7 +23,39 @@ package( iree_lit_test_suite( name = "lit", - srcs = glob(["*.mlir"]), + srcs = enforce_glob( + [ + "create_benchmark_funcs.mlir", + "deduplicate_executables.mlir", + "dispatch_linalg_on_tensors.mlir", + "dispatch_linalg_on_tensors_dynamic.mlir", + "expand_variable_dynamic_dims.mlir", + "fold_compatible_dispatch_regions.mlir", + "form_streams.mlir", + "hlo_to_hlo_preprocessing.mlir", + "hlo_to_hlo_preprocessing_canoncalize_dot_general.mlir", + "hlo_to_hlo_preprocessing_extract_pad_from_conv.mlir", + "hoist_unstreamable_ops.mlir", + "identify_dispatch_regions.mlir", + "identify_dispatch_regions2_enable_matmul_fusion.mlir", + "identify_dispatch_regions2_hlo.mlir", + "identify_dispatch_regions2_linalg.mlir", + "identify_dispatch_regions2_shapes.mlir", + "identify_dispatch_regions2_std_fusion.mlir", + "inject_dispatch_tracing.mlir", + "legalize_input_types.mlir", + "materialize_and_merge_exported_reflection.mlir", + "materialize_exported_reflection.mlir", + "merge_exported_reflection.mlir", + "outline_dispatch_regions2.mlir", + "outline_dispatch_regions_ranked_dynamic.mlir", + "outline_large_constants.mlir", + "rematerialize_dispatch_constants.mlir", + "strip_and_splat_constant_variables.mlir", + "transformation.mlir", + ], + include = ["*.mlir"], + ), data = [ "//iree/tools:IreeFileCheck", "//iree/tools:iree-opt", diff --git a/iree/compiler/Dialect/Flow/Transforms/test/CMakeLists.txt b/iree/compiler/Dialect/Flow/Transforms/test/CMakeLists.txt index ddcc3cde95d4..f1cf8b711124 100644 --- a/iree/compiler/Dialect/Flow/Transforms/test/CMakeLists.txt +++ b/iree/compiler/Dialect/Flow/Transforms/test/CMakeLists.txt @@ -10,12 +10,38 @@ iree_add_all_subdirs() -file(GLOB _GLOB_X_MLIR LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.mlir) iree_lit_test_suite( NAME lit SRCS - "${_GLOB_X_MLIR}" + "create_benchmark_funcs.mlir" + "deduplicate_executables.mlir" + "dispatch_linalg_on_tensors.mlir" + "dispatch_linalg_on_tensors_dynamic.mlir" + "expand_variable_dynamic_dims.mlir" + "fold_compatible_dispatch_regions.mlir" + "form_streams.mlir" + "hlo_to_hlo_preprocessing.mlir" + "hlo_to_hlo_preprocessing_canoncalize_dot_general.mlir" + "hlo_to_hlo_preprocessing_extract_pad_from_conv.mlir" + "hoist_unstreamable_ops.mlir" + "identify_dispatch_regions.mlir" + "identify_dispatch_regions2_enable_matmul_fusion.mlir" + "identify_dispatch_regions2_hlo.mlir" + "identify_dispatch_regions2_linalg.mlir" + "identify_dispatch_regions2_shapes.mlir" + "identify_dispatch_regions2_std_fusion.mlir" + "inject_dispatch_tracing.mlir" + "legalize_input_types.mlir" + "materialize_and_merge_exported_reflection.mlir" + "materialize_exported_reflection.mlir" + "merge_exported_reflection.mlir" + "outline_dispatch_regions2.mlir" + "outline_dispatch_regions_ranked_dynamic.mlir" + "outline_large_constants.mlir" + "rematerialize_dispatch_constants.mlir" + "strip_and_splat_constant_variables.mlir" + "transformation.mlir" DATA iree::tools::IreeFileCheck iree::tools::iree-opt diff --git a/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/test/BUILD b/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/test/BUILD index b780b112a316..e55486242ff6 100644 --- a/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/test/BUILD +++ b/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/test/BUILD @@ -13,6 +13,7 @@ # limitations under the License. load("//iree:lit_test.bzl", "iree_lit_test_suite") +load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob") package( default_visibility = ["//visibility:public"], @@ -22,7 +23,15 @@ package( iree_lit_test_suite( name = "lit", - srcs = glob(["*.mlir"]), + srcs = enforce_glob( + [ + "constant_ops.mlir", + "stream_ops.mlir", + "tensor_ops.mlir", + "variable_ops.mlir", + ], + include = ["*.mlir"], + ), data = [ "//iree/tools:IreeFileCheck", "//iree/tools:iree-opt", diff --git a/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/test/CMakeLists.txt b/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/test/CMakeLists.txt index 17b455554a40..ceb1dcf6a018 100644 --- a/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/test/CMakeLists.txt +++ b/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/test/CMakeLists.txt @@ -10,12 +10,14 @@ iree_add_all_subdirs() -file(GLOB _GLOB_X_MLIR LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.mlir) iree_lit_test_suite( NAME lit SRCS - "${_GLOB_X_MLIR}" + "constant_ops.mlir" + "stream_ops.mlir" + "tensor_ops.mlir" + "variable_ops.mlir" DATA iree::tools::IreeFileCheck iree::tools::iree-opt diff --git a/iree/compiler/Dialect/HAL/Conversion/HALToHAL/test/BUILD b/iree/compiler/Dialect/HAL/Conversion/HALToHAL/test/BUILD index b780b112a316..f63519c37096 100644 --- a/iree/compiler/Dialect/HAL/Conversion/HALToHAL/test/BUILD +++ b/iree/compiler/Dialect/HAL/Conversion/HALToHAL/test/BUILD @@ -13,6 +13,7 @@ # limitations under the License. load("//iree:lit_test.bzl", "iree_lit_test_suite") +load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob") package( default_visibility = ["//visibility:public"], @@ -22,7 +23,10 @@ package( iree_lit_test_suite( name = "lit", - srcs = glob(["*.mlir"]), + srcs = enforce_glob( + ["constant_ops.mlir"], + include = ["*.mlir"], + ), data = [ "//iree/tools:IreeFileCheck", "//iree/tools:iree-opt", diff --git a/iree/compiler/Dialect/HAL/Conversion/HALToHAL/test/CMakeLists.txt b/iree/compiler/Dialect/HAL/Conversion/HALToHAL/test/CMakeLists.txt index 910ac9537da0..063d8b0269a4 100644 --- a/iree/compiler/Dialect/HAL/Conversion/HALToHAL/test/CMakeLists.txt +++ b/iree/compiler/Dialect/HAL/Conversion/HALToHAL/test/CMakeLists.txt @@ -10,12 +10,11 @@ iree_add_all_subdirs() -file(GLOB _GLOB_X_MLIR LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.mlir) iree_lit_test_suite( NAME lit SRCS - "${_GLOB_X_MLIR}" + "constant_ops.mlir" DATA iree::tools::IreeFileCheck iree::tools::iree-opt diff --git a/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/BUILD b/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/BUILD index b780b112a316..94f1dc8bd15e 100644 --- a/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/BUILD +++ b/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/BUILD @@ -13,6 +13,7 @@ # limitations under the License. load("//iree:lit_test.bzl", "iree_lit_test_suite") +load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob") package( default_visibility = ["//visibility:public"], @@ -22,7 +23,20 @@ package( iree_lit_test_suite( name = "lit", - srcs = glob(["*.mlir"]), + srcs = enforce_glob( + [ + "allocator_ops.mlir", + "buffer_ops.mlir", + "buffer_view_ops.mlir", + "command_buffer_ops.mlir", + "constant_ops.mlir", + "control_flow_ops.mlir", + "device_ops.mlir", + "executable_ops.mlir", + "variable_ops.mlir", + ], + include = ["*.mlir"], + ), data = [ "//iree/tools:IreeFileCheck", "//iree/tools:iree-opt", diff --git a/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/CMakeLists.txt b/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/CMakeLists.txt index 7844d24c0128..fa16efe8eca9 100644 --- a/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/CMakeLists.txt +++ b/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/CMakeLists.txt @@ -10,12 +10,19 @@ iree_add_all_subdirs() -file(GLOB _GLOB_X_MLIR LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.mlir) iree_lit_test_suite( NAME lit SRCS - "${_GLOB_X_MLIR}" + "allocator_ops.mlir" + "buffer_ops.mlir" + "buffer_view_ops.mlir" + "command_buffer_ops.mlir" + "constant_ops.mlir" + "control_flow_ops.mlir" + "device_ops.mlir" + "executable_ops.mlir" + "variable_ops.mlir" DATA iree::tools::IreeFileCheck iree::tools::iree-opt diff --git a/iree/compiler/Dialect/HAL/Conversion/IREEToHAL/test/BUILD b/iree/compiler/Dialect/HAL/Conversion/IREEToHAL/test/BUILD index b780b112a316..a9b6fa17b0c6 100644 --- a/iree/compiler/Dialect/HAL/Conversion/IREEToHAL/test/BUILD +++ b/iree/compiler/Dialect/HAL/Conversion/IREEToHAL/test/BUILD @@ -13,6 +13,7 @@ # limitations under the License. load("//iree:lit_test.bzl", "iree_lit_test_suite") +load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob") package( default_visibility = ["//visibility:public"], @@ -22,7 +23,10 @@ package( iree_lit_test_suite( name = "lit", - srcs = glob(["*.mlir"]), + srcs = enforce_glob( + ["shape_constants.mlir"], + include = ["*.mlir"], + ), data = [ "//iree/tools:IreeFileCheck", "//iree/tools:iree-opt", diff --git a/iree/compiler/Dialect/HAL/Conversion/IREEToHAL/test/CMakeLists.txt b/iree/compiler/Dialect/HAL/Conversion/IREEToHAL/test/CMakeLists.txt index cee582d2149d..a6025156f003 100644 --- a/iree/compiler/Dialect/HAL/Conversion/IREEToHAL/test/CMakeLists.txt +++ b/iree/compiler/Dialect/HAL/Conversion/IREEToHAL/test/CMakeLists.txt @@ -10,12 +10,11 @@ iree_add_all_subdirs() -file(GLOB _GLOB_X_MLIR LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.mlir) iree_lit_test_suite( NAME lit SRCS - "${_GLOB_X_MLIR}" + "shape_constants.mlir" DATA iree::tools::IreeFileCheck iree::tools::iree-opt diff --git a/iree/compiler/Dialect/HAL/Conversion/StandardToHAL/test/BUILD b/iree/compiler/Dialect/HAL/Conversion/StandardToHAL/test/BUILD index b780b112a316..51722bc38906 100644 --- a/iree/compiler/Dialect/HAL/Conversion/StandardToHAL/test/BUILD +++ b/iree/compiler/Dialect/HAL/Conversion/StandardToHAL/test/BUILD @@ -13,6 +13,7 @@ # limitations under the License. load("//iree:lit_test.bzl", "iree_lit_test_suite") +load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob") package( default_visibility = ["//visibility:public"], @@ -22,7 +23,10 @@ package( iree_lit_test_suite( name = "lit", - srcs = glob(["*.mlir"]), + srcs = enforce_glob( + ["structural_ops.mlir"], + include = ["*.mlir"], + ), data = [ "//iree/tools:IreeFileCheck", "//iree/tools:iree-opt", diff --git a/iree/compiler/Dialect/HAL/Conversion/StandardToHAL/test/CMakeLists.txt b/iree/compiler/Dialect/HAL/Conversion/StandardToHAL/test/CMakeLists.txt index ed79fde93dc2..aec20652b290 100644 --- a/iree/compiler/Dialect/HAL/Conversion/StandardToHAL/test/CMakeLists.txt +++ b/iree/compiler/Dialect/HAL/Conversion/StandardToHAL/test/CMakeLists.txt @@ -10,12 +10,11 @@ iree_add_all_subdirs() -file(GLOB _GLOB_X_MLIR LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.mlir) iree_lit_test_suite( NAME lit SRCS - "${_GLOB_X_MLIR}" + "structural_ops.mlir" DATA iree::tools::IreeFileCheck iree::tools::iree-opt diff --git a/iree/compiler/Dialect/HAL/IR/BUILD b/iree/compiler/Dialect/HAL/IR/BUILD index d759f0e3ea2e..37671708432b 100644 --- a/iree/compiler/Dialect/HAL/IR/BUILD +++ b/iree/compiler/Dialect/HAL/IR/BUILD @@ -14,6 +14,7 @@ load("//build_tools/bazel:iree_tablegen_doc.bzl", "iree_tablegen_doc") load("//build_tools/bazel:tblgen.bzl", "gentbl") +load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob") package( default_visibility = ["//visibility:public"], @@ -25,7 +26,13 @@ exports_files(["HALBase.td"]) filegroup( name = "td_files", - srcs = glob(["*.td"]), + srcs = enforce_glob( + [ + "HALBase.td", + "HALOps.td", + ], + include = ["*.td"], + ), ) cc_library( diff --git a/iree/compiler/Dialect/HAL/IR/CMakeLists.txt b/iree/compiler/Dialect/HAL/IR/CMakeLists.txt index 06072e3697e2..d548abe427ff 100644 --- a/iree/compiler/Dialect/HAL/IR/CMakeLists.txt +++ b/iree/compiler/Dialect/HAL/IR/CMakeLists.txt @@ -10,7 +10,6 @@ iree_add_all_subdirs() -file(GLOB _GLOB_X_TD LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.td) iree_cc_library( NAME IR diff --git a/iree/compiler/Dialect/HAL/IR/test/BUILD b/iree/compiler/Dialect/HAL/IR/test/BUILD index b780b112a316..a4e7c0ba2d58 100644 --- a/iree/compiler/Dialect/HAL/IR/test/BUILD +++ b/iree/compiler/Dialect/HAL/IR/test/BUILD @@ -13,6 +13,7 @@ # limitations under the License. load("//iree:lit_test.bzl", "iree_lit_test_suite") +load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob") package( default_visibility = ["//visibility:public"], @@ -22,7 +23,29 @@ package( iree_lit_test_suite( name = "lit", - srcs = glob(["*.mlir"]), + srcs = enforce_glob( + [ + "allocator_ops.mlir", + "attributes.mlir", + "buffer_folding.mlir", + "buffer_ops.mlir", + "buffer_view_folding.mlir", + "buffer_view_ops.mlir", + "command_buffer_folding.mlir", + "command_buffer_ops.mlir", + "constant_folding.mlir", + "constant_ops.mlir", + "descriptor_set_ops.mlir", + "device_ops.mlir", + "executable_ops.mlir", + "experimental_ops.mlir", + "interface_ops.mlir", + "semaphore_ops.mlir", + "variable_folding.mlir", + "variable_ops.mlir", + ], + include = ["*.mlir"], + ), data = [ "//iree/tools:IreeFileCheck", "//iree/tools:iree-opt", diff --git a/iree/compiler/Dialect/HAL/IR/test/CMakeLists.txt b/iree/compiler/Dialect/HAL/IR/test/CMakeLists.txt index 578bd7d49161..172247e7f333 100644 --- a/iree/compiler/Dialect/HAL/IR/test/CMakeLists.txt +++ b/iree/compiler/Dialect/HAL/IR/test/CMakeLists.txt @@ -10,12 +10,28 @@ iree_add_all_subdirs() -file(GLOB _GLOB_X_MLIR LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.mlir) iree_lit_test_suite( NAME lit SRCS - "${_GLOB_X_MLIR}" + "allocator_ops.mlir" + "attributes.mlir" + "buffer_folding.mlir" + "buffer_ops.mlir" + "buffer_view_folding.mlir" + "buffer_view_ops.mlir" + "command_buffer_folding.mlir" + "command_buffer_ops.mlir" + "constant_folding.mlir" + "constant_ops.mlir" + "descriptor_set_ops.mlir" + "device_ops.mlir" + "executable_ops.mlir" + "experimental_ops.mlir" + "interface_ops.mlir" + "semaphore_ops.mlir" + "variable_folding.mlir" + "variable_ops.mlir" DATA iree::tools::IreeFileCheck iree::tools::iree-opt diff --git a/iree/compiler/Dialect/HAL/Target/CUDA/test/BUILD b/iree/compiler/Dialect/HAL/Target/CUDA/test/BUILD index 326145d04b6c..793f1d4e084f 100644 --- a/iree/compiler/Dialect/HAL/Target/CUDA/test/BUILD +++ b/iree/compiler/Dialect/HAL/Target/CUDA/test/BUILD @@ -13,6 +13,7 @@ # limitations under the License. load("//iree:lit_test.bzl", "iree_lit_test_suite") +load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob") package( default_visibility = ["//visibility:public"], @@ -22,7 +23,10 @@ package( iree_lit_test_suite( name = "lit", - srcs = glob(["*.mlir"]), + srcs = enforce_glob( + ["smoketest.mlir"], + include = ["*.mlir"], + ), data = [ "//iree/tools:IreeFileCheck", "//iree/tools:iree-opt", diff --git a/iree/compiler/Dialect/HAL/Target/CUDA/test/CMakeLists.txt b/iree/compiler/Dialect/HAL/Target/CUDA/test/CMakeLists.txt index c461923966d9..229736f77813 100644 --- a/iree/compiler/Dialect/HAL/Target/CUDA/test/CMakeLists.txt +++ b/iree/compiler/Dialect/HAL/Target/CUDA/test/CMakeLists.txt @@ -10,12 +10,11 @@ iree_add_all_subdirs() -file(GLOB _GLOB_X_MLIR LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.mlir) iree_lit_test_suite( NAME lit SRCS - "${_GLOB_X_MLIR}" + "smoketest.mlir" DATA iree::tools::IreeFileCheck iree::tools::iree-opt diff --git a/iree/compiler/Dialect/HAL/Target/LLVM/test/BUILD b/iree/compiler/Dialect/HAL/Target/LLVM/test/BUILD index bb9c0e0808d7..8d49ad95b14e 100644 --- a/iree/compiler/Dialect/HAL/Target/LLVM/test/BUILD +++ b/iree/compiler/Dialect/HAL/Target/LLVM/test/BUILD @@ -13,6 +13,7 @@ # limitations under the License. load("//iree:lit_test.bzl", "iree_lit_test_suite") +load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob") package( default_visibility = ["//visibility:public"], @@ -22,7 +23,13 @@ package( iree_lit_test_suite( name = "lit", - srcs = glob(["*.mlir"]), + srcs = enforce_glob( + [ + "binary_op.mlir", + "matmul_op.mlir", + ], + include = ["*.mlir"], + ), data = [ "//iree/tools:IreeFileCheck", "//iree/tools:iree-opt", diff --git a/iree/compiler/Dialect/HAL/Target/LLVM/test/CMakeLists.txt b/iree/compiler/Dialect/HAL/Target/LLVM/test/CMakeLists.txt index 766c6e249488..c43c1cdf01f5 100644 --- a/iree/compiler/Dialect/HAL/Target/LLVM/test/CMakeLists.txt +++ b/iree/compiler/Dialect/HAL/Target/LLVM/test/CMakeLists.txt @@ -10,12 +10,12 @@ iree_add_all_subdirs() -file(GLOB _GLOB_X_MLIR LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.mlir) iree_lit_test_suite( NAME lit SRCS - "${_GLOB_X_MLIR}" + "binary_op.mlir" + "matmul_op.mlir" DATA iree::tools::IreeFileCheck iree::tools::iree-opt diff --git a/iree/compiler/Dialect/HAL/Target/VMLA/test/BUILD b/iree/compiler/Dialect/HAL/Target/VMLA/test/BUILD index bb9c0e0808d7..fa884d612e5a 100644 --- a/iree/compiler/Dialect/HAL/Target/VMLA/test/BUILD +++ b/iree/compiler/Dialect/HAL/Target/VMLA/test/BUILD @@ -13,6 +13,7 @@ # limitations under the License. load("//iree:lit_test.bzl", "iree_lit_test_suite") +load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob") package( default_visibility = ["//visibility:public"], @@ -22,7 +23,14 @@ package( iree_lit_test_suite( name = "lit", - srcs = glob(["*.mlir"]), + srcs = enforce_glob( + [ + "i1_types.mlir", + "linking.mlir", + "smoketest.mlir", + ], + include = ["*.mlir"], + ), data = [ "//iree/tools:IreeFileCheck", "//iree/tools:iree-opt", diff --git a/iree/compiler/Dialect/HAL/Target/VMLA/test/CMakeLists.txt b/iree/compiler/Dialect/HAL/Target/VMLA/test/CMakeLists.txt index 8aba5262972f..4a238ee67584 100644 --- a/iree/compiler/Dialect/HAL/Target/VMLA/test/CMakeLists.txt +++ b/iree/compiler/Dialect/HAL/Target/VMLA/test/CMakeLists.txt @@ -10,12 +10,13 @@ iree_add_all_subdirs() -file(GLOB _GLOB_X_MLIR LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.mlir) iree_lit_test_suite( NAME lit SRCS - "${_GLOB_X_MLIR}" + "i1_types.mlir" + "linking.mlir" + "smoketest.mlir" DATA iree::tools::IreeFileCheck iree::tools::iree-opt diff --git a/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/test/BUILD b/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/test/BUILD index bb9c0e0808d7..0e3700209bf6 100644 --- a/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/test/BUILD +++ b/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/test/BUILD @@ -13,6 +13,7 @@ # limitations under the License. load("//iree:lit_test.bzl", "iree_lit_test_suite") +load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob") package( default_visibility = ["//visibility:public"], @@ -22,7 +23,10 @@ package( iree_lit_test_suite( name = "lit", - srcs = glob(["*.mlir"]), + srcs = enforce_glob( + ["smoketest.mlir"], + include = ["*.mlir"], + ), data = [ "//iree/tools:IreeFileCheck", "//iree/tools:iree-opt", diff --git a/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/test/CMakeLists.txt b/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/test/CMakeLists.txt index 29cc1273de8f..c2c0cf180b0c 100644 --- a/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/test/CMakeLists.txt +++ b/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/test/CMakeLists.txt @@ -10,12 +10,11 @@ iree_add_all_subdirs() -file(GLOB _GLOB_X_MLIR LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.mlir) iree_lit_test_suite( NAME lit SRCS - "${_GLOB_X_MLIR}" + "smoketest.mlir" DATA iree::tools::IreeFileCheck iree::tools::iree-opt diff --git a/iree/compiler/Dialect/HAL/Transforms/test/BUILD b/iree/compiler/Dialect/HAL/Transforms/test/BUILD index b780b112a316..f902b2514044 100644 --- a/iree/compiler/Dialect/HAL/Transforms/test/BUILD +++ b/iree/compiler/Dialect/HAL/Transforms/test/BUILD @@ -13,6 +13,7 @@ # limitations under the License. load("//iree:lit_test.bzl", "iree_lit_test_suite") +load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob") package( default_visibility = ["//visibility:public"], @@ -22,7 +23,23 @@ package( iree_lit_test_suite( name = "lit", - srcs = glob(["*.mlir"]), + srcs = enforce_glob( + [ + "benchmark_batch_dispatches.mlir", + "cse_variable_loads.mlir", + "identify_constant_pools.mlir", + "inline_device_switches.mlir", + "materialize_constant_pool_buffers.mlir", + "materialize_interfaces.mlir", + "materialize_resource_caches.mlir", + "memoize_device_queries.mlir", + "pack_constant_pool_storage.mlir", + "propagate_constant_workgroup_info.mlir", + "public_abi_generation.mlir", + "resolve_entry_point_ordinals.mlir", + ], + include = ["*.mlir"], + ), data = [ "//iree/tools:IreeFileCheck", "//iree/tools:iree-opt", diff --git a/iree/compiler/Dialect/HAL/Transforms/test/CMakeLists.txt b/iree/compiler/Dialect/HAL/Transforms/test/CMakeLists.txt index b8831ee32130..676ac1c612a1 100644 --- a/iree/compiler/Dialect/HAL/Transforms/test/CMakeLists.txt +++ b/iree/compiler/Dialect/HAL/Transforms/test/CMakeLists.txt @@ -10,12 +10,22 @@ iree_add_all_subdirs() -file(GLOB _GLOB_X_MLIR LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.mlir) iree_lit_test_suite( NAME lit SRCS - "${_GLOB_X_MLIR}" + "benchmark_batch_dispatches.mlir" + "cse_variable_loads.mlir" + "identify_constant_pools.mlir" + "inline_device_switches.mlir" + "materialize_constant_pool_buffers.mlir" + "materialize_interfaces.mlir" + "materialize_resource_caches.mlir" + "memoize_device_queries.mlir" + "pack_constant_pool_storage.mlir" + "propagate_constant_workgroup_info.mlir" + "public_abi_generation.mlir" + "resolve_entry_point_ordinals.mlir" DATA iree::tools::IreeFileCheck iree::tools::iree-opt diff --git a/iree/compiler/Dialect/IREE/Conversion/test/BUILD b/iree/compiler/Dialect/IREE/Conversion/test/BUILD index b780b112a316..21bed501e8cf 100644 --- a/iree/compiler/Dialect/IREE/Conversion/test/BUILD +++ b/iree/compiler/Dialect/IREE/Conversion/test/BUILD @@ -13,6 +13,7 @@ # limitations under the License. load("//iree:lit_test.bzl", "iree_lit_test_suite") +load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob") package( default_visibility = ["//visibility:public"], @@ -22,7 +23,10 @@ package( iree_lit_test_suite( name = "lit", - srcs = glob(["*.mlir"]), + srcs = enforce_glob( + ["convert_flow_to_hal.mlir"], + include = ["*.mlir"], + ), data = [ "//iree/tools:IreeFileCheck", "//iree/tools:iree-opt", diff --git a/iree/compiler/Dialect/IREE/Conversion/test/CMakeLists.txt b/iree/compiler/Dialect/IREE/Conversion/test/CMakeLists.txt index 203b15530bf5..5dc210ff6429 100644 --- a/iree/compiler/Dialect/IREE/Conversion/test/CMakeLists.txt +++ b/iree/compiler/Dialect/IREE/Conversion/test/CMakeLists.txt @@ -10,12 +10,11 @@ iree_add_all_subdirs() -file(GLOB _GLOB_X_MLIR LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.mlir) iree_lit_test_suite( NAME lit SRCS - "${_GLOB_X_MLIR}" + "convert_flow_to_hal.mlir" DATA iree::tools::IreeFileCheck iree::tools::iree-opt diff --git a/iree/compiler/Dialect/IREE/IR/BUILD b/iree/compiler/Dialect/IREE/IR/BUILD index e04be16fe782..d9da85ae50b4 100644 --- a/iree/compiler/Dialect/IREE/IR/BUILD +++ b/iree/compiler/Dialect/IREE/IR/BUILD @@ -14,6 +14,7 @@ load("//build_tools/bazel:iree_tablegen_doc.bzl", "iree_tablegen_doc") load("//build_tools/bazel:tblgen.bzl", "gentbl") +load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob") package( default_visibility = ["//visibility:public"], @@ -25,7 +26,13 @@ exports_files(["IREEBase.td"]) filegroup( name = "td_files", - srcs = glob(["*.td"]), + srcs = enforce_glob( + [ + "IREEBase.td", + "IREEOps.td", + ], + include = ["*.td"], + ), ) cc_library( diff --git a/iree/compiler/Dialect/IREE/IR/CMakeLists.txt b/iree/compiler/Dialect/IREE/IR/CMakeLists.txt index 14f21a4c7294..1758d1cd6254 100644 --- a/iree/compiler/Dialect/IREE/IR/CMakeLists.txt +++ b/iree/compiler/Dialect/IREE/IR/CMakeLists.txt @@ -10,7 +10,6 @@ iree_add_all_subdirs() -file(GLOB _GLOB_X_TD LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.td) iree_cc_library( NAME IR diff --git a/iree/compiler/Dialect/IREE/IR/test/BUILD b/iree/compiler/Dialect/IREE/IR/test/BUILD index b780b112a316..aacb0e0f1def 100644 --- a/iree/compiler/Dialect/IREE/IR/test/BUILD +++ b/iree/compiler/Dialect/IREE/IR/test/BUILD @@ -13,6 +13,7 @@ # limitations under the License. load("//iree:lit_test.bzl", "iree_lit_test_suite") +load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob") package( default_visibility = ["//visibility:public"], @@ -22,7 +23,14 @@ package( iree_lit_test_suite( name = "lit", - srcs = glob(["*.mlir"]), + srcs = enforce_glob( + [ + "byte_buffer_ops.mlir", + "do_not_optimize.mlir", + "parse_print.mlir", + ], + include = ["*.mlir"], + ), data = [ "//iree/tools:IreeFileCheck", "//iree/tools:iree-opt", diff --git a/iree/compiler/Dialect/IREE/IR/test/CMakeLists.txt b/iree/compiler/Dialect/IREE/IR/test/CMakeLists.txt index 4b02562ac5b7..34fbf727c1c8 100644 --- a/iree/compiler/Dialect/IREE/IR/test/CMakeLists.txt +++ b/iree/compiler/Dialect/IREE/IR/test/CMakeLists.txt @@ -10,12 +10,13 @@ iree_add_all_subdirs() -file(GLOB _GLOB_X_MLIR LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.mlir) iree_lit_test_suite( NAME lit SRCS - "${_GLOB_X_MLIR}" + "byte_buffer_ops.mlir" + "do_not_optimize.mlir" + "parse_print.mlir" DATA iree::tools::IreeFileCheck iree::tools::iree-opt diff --git a/iree/compiler/Dialect/IREE/Transforms/test/BUILD b/iree/compiler/Dialect/IREE/Transforms/test/BUILD index b780b112a316..061dc7743d8f 100644 --- a/iree/compiler/Dialect/IREE/Transforms/test/BUILD +++ b/iree/compiler/Dialect/IREE/Transforms/test/BUILD @@ -13,6 +13,7 @@ # limitations under the License. load("//iree:lit_test.bzl", "iree_lit_test_suite") +load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob") package( default_visibility = ["//visibility:public"], @@ -22,7 +23,10 @@ package( iree_lit_test_suite( name = "lit", - srcs = glob(["*.mlir"]), + srcs = enforce_glob( + ["drop_compiler_hints.mlir"], + include = ["*.mlir"], + ), data = [ "//iree/tools:IreeFileCheck", "//iree/tools:iree-opt", diff --git a/iree/compiler/Dialect/IREE/Transforms/test/CMakeLists.txt b/iree/compiler/Dialect/IREE/Transforms/test/CMakeLists.txt index 1de97a594fc5..3638b362f24d 100644 --- a/iree/compiler/Dialect/IREE/Transforms/test/CMakeLists.txt +++ b/iree/compiler/Dialect/IREE/Transforms/test/CMakeLists.txt @@ -10,12 +10,11 @@ iree_add_all_subdirs() -file(GLOB _GLOB_X_MLIR LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.mlir) iree_lit_test_suite( NAME lit SRCS - "${_GLOB_X_MLIR}" + "drop_compiler_hints.mlir" DATA iree::tools::IreeFileCheck iree::tools::iree-opt diff --git a/iree/compiler/Dialect/Modules/Check/IR/BUILD b/iree/compiler/Dialect/Modules/Check/IR/BUILD index fd35c89b731e..5ca666407258 100644 --- a/iree/compiler/Dialect/Modules/Check/IR/BUILD +++ b/iree/compiler/Dialect/Modules/Check/IR/BUILD @@ -14,6 +14,7 @@ load("//build_tools/bazel:iree_tablegen_doc.bzl", "iree_tablegen_doc") load("//build_tools/bazel:tblgen.bzl", "gentbl") +load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob") package( default_visibility = ["//visibility:public"], @@ -23,7 +24,10 @@ package( filegroup( name = "td_files", - srcs = glob(["*.td"]), + srcs = enforce_glob( + ["CheckOps.td"], + include = ["*.td"], + ), ) cc_library( diff --git a/iree/compiler/Dialect/Modules/Check/IR/CMakeLists.txt b/iree/compiler/Dialect/Modules/Check/IR/CMakeLists.txt index 508b269a4400..6eb8f4b5bdd2 100644 --- a/iree/compiler/Dialect/Modules/Check/IR/CMakeLists.txt +++ b/iree/compiler/Dialect/Modules/Check/IR/CMakeLists.txt @@ -10,7 +10,6 @@ iree_add_all_subdirs() -file(GLOB _GLOB_X_TD LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.td) iree_cc_library( NAME IR diff --git a/iree/compiler/Dialect/Modules/Check/test/BUILD b/iree/compiler/Dialect/Modules/Check/test/BUILD index bb9c0e0808d7..d931dda42da7 100644 --- a/iree/compiler/Dialect/Modules/Check/test/BUILD +++ b/iree/compiler/Dialect/Modules/Check/test/BUILD @@ -13,6 +13,7 @@ # limitations under the License. load("//iree:lit_test.bzl", "iree_lit_test_suite") +load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob") package( default_visibility = ["//visibility:public"], @@ -22,7 +23,13 @@ package( iree_lit_test_suite( name = "lit", - srcs = glob(["*.mlir"]), + srcs = enforce_glob( + [ + "canonicalize.mlir", + "ops.mlir", + ], + include = ["*.mlir"], + ), data = [ "//iree/tools:IreeFileCheck", "//iree/tools:iree-opt", diff --git a/iree/compiler/Dialect/Modules/Check/test/CMakeLists.txt b/iree/compiler/Dialect/Modules/Check/test/CMakeLists.txt index 90487d98c828..263842a4761b 100644 --- a/iree/compiler/Dialect/Modules/Check/test/CMakeLists.txt +++ b/iree/compiler/Dialect/Modules/Check/test/CMakeLists.txt @@ -10,12 +10,12 @@ iree_add_all_subdirs() -file(GLOB _GLOB_X_MLIR LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.mlir) iree_lit_test_suite( NAME lit SRCS - "${_GLOB_X_MLIR}" + "canonicalize.mlir" + "ops.mlir" DATA iree::tools::IreeFileCheck iree::tools::iree-opt diff --git a/iree/compiler/Dialect/Modules/Strings/Conversion/test/BUILD b/iree/compiler/Dialect/Modules/Strings/Conversion/test/BUILD index bb9c0e0808d7..22b5c8ff4c51 100644 --- a/iree/compiler/Dialect/Modules/Strings/Conversion/test/BUILD +++ b/iree/compiler/Dialect/Modules/Strings/Conversion/test/BUILD @@ -13,6 +13,7 @@ # limitations under the License. load("//iree:lit_test.bzl", "iree_lit_test_suite") +load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob") package( default_visibility = ["//visibility:public"], @@ -22,7 +23,10 @@ package( iree_lit_test_suite( name = "lit", - srcs = glob(["*.mlir"]), + srcs = enforce_glob( + ["convert_to_hal.mlir"], + include = ["*.mlir"], + ), data = [ "//iree/tools:IreeFileCheck", "//iree/tools:iree-opt", diff --git a/iree/compiler/Dialect/Modules/Strings/Conversion/test/CMakeLists.txt b/iree/compiler/Dialect/Modules/Strings/Conversion/test/CMakeLists.txt index 3d350faefc84..1d7774acf461 100644 --- a/iree/compiler/Dialect/Modules/Strings/Conversion/test/CMakeLists.txt +++ b/iree/compiler/Dialect/Modules/Strings/Conversion/test/CMakeLists.txt @@ -10,12 +10,11 @@ iree_add_all_subdirs() -file(GLOB _GLOB_X_MLIR LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.mlir) iree_lit_test_suite( NAME lit SRCS - "${_GLOB_X_MLIR}" + "convert_to_hal.mlir" DATA iree::tools::IreeFileCheck iree::tools::iree-opt diff --git a/iree/compiler/Dialect/Modules/Strings/IR/BUILD b/iree/compiler/Dialect/Modules/Strings/IR/BUILD index 6967ede925cc..e5a5fa414514 100644 --- a/iree/compiler/Dialect/Modules/Strings/IR/BUILD +++ b/iree/compiler/Dialect/Modules/Strings/IR/BUILD @@ -14,6 +14,7 @@ load("//build_tools/bazel:iree_tablegen_doc.bzl", "iree_tablegen_doc") load("//build_tools/bazel:tblgen.bzl", "gentbl") +load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob") package( default_visibility = ["//visibility:public"], @@ -23,7 +24,10 @@ package( filegroup( name = "td_files", - srcs = glob(["*.td"]), + srcs = enforce_glob( + ["Ops.td"], + include = ["*.td"], + ), ) cc_library( diff --git a/iree/compiler/Dialect/Modules/Strings/IR/CMakeLists.txt b/iree/compiler/Dialect/Modules/Strings/IR/CMakeLists.txt index d310fe234796..636318f6cdb3 100644 --- a/iree/compiler/Dialect/Modules/Strings/IR/CMakeLists.txt +++ b/iree/compiler/Dialect/Modules/Strings/IR/CMakeLists.txt @@ -10,7 +10,6 @@ iree_add_all_subdirs() -file(GLOB _GLOB_X_TD LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.td) iree_cc_library( NAME IR diff --git a/iree/compiler/Dialect/Modules/Strings/IR/test/BUILD b/iree/compiler/Dialect/Modules/Strings/IR/test/BUILD index b780b112a316..bfb77fa33c7f 100644 --- a/iree/compiler/Dialect/Modules/Strings/IR/test/BUILD +++ b/iree/compiler/Dialect/Modules/Strings/IR/test/BUILD @@ -13,6 +13,7 @@ # limitations under the License. load("//iree:lit_test.bzl", "iree_lit_test_suite") +load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob") package( default_visibility = ["//visibility:public"], @@ -22,7 +23,10 @@ package( iree_lit_test_suite( name = "lit", - srcs = glob(["*.mlir"]), + srcs = enforce_glob( + ["strings_ops.mlir"], + include = ["*.mlir"], + ), data = [ "//iree/tools:IreeFileCheck", "//iree/tools:iree-opt", diff --git a/iree/compiler/Dialect/Modules/Strings/IR/test/CMakeLists.txt b/iree/compiler/Dialect/Modules/Strings/IR/test/CMakeLists.txt index bcb5d45b497a..47e93c248172 100644 --- a/iree/compiler/Dialect/Modules/Strings/IR/test/CMakeLists.txt +++ b/iree/compiler/Dialect/Modules/Strings/IR/test/CMakeLists.txt @@ -10,12 +10,11 @@ iree_add_all_subdirs() -file(GLOB _GLOB_X_MLIR LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.mlir) iree_lit_test_suite( NAME lit SRCS - "${_GLOB_X_MLIR}" + "strings_ops.mlir" DATA iree::tools::IreeFileCheck iree::tools::iree-opt diff --git a/iree/compiler/Dialect/Modules/TensorList/BUILD b/iree/compiler/Dialect/Modules/TensorList/BUILD index 97fbca4dce3e..57e8c589a51a 100644 --- a/iree/compiler/Dialect/Modules/TensorList/BUILD +++ b/iree/compiler/Dialect/Modules/TensorList/BUILD @@ -20,11 +20,6 @@ package( licenses = ["notice"], # Apache 2.0 ) -filegroup( - name = "td_files", - srcs = glob(["*.td"]), -) - cc_embed_data( name = "tensorlist_imports", srcs = ["tensorlist.imports.mlir"], diff --git a/iree/compiler/Dialect/Modules/TensorList/CMakeLists.txt b/iree/compiler/Dialect/Modules/TensorList/CMakeLists.txt index b100495ab132..795076a278bc 100644 --- a/iree/compiler/Dialect/Modules/TensorList/CMakeLists.txt +++ b/iree/compiler/Dialect/Modules/TensorList/CMakeLists.txt @@ -10,7 +10,6 @@ iree_add_all_subdirs() -file(GLOB _GLOB_X_TD LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.td) iree_cc_embed_data( NAME tensorlist_imports diff --git a/iree/compiler/Dialect/Modules/TensorList/Conversion/test/BUILD b/iree/compiler/Dialect/Modules/TensorList/Conversion/test/BUILD index b780b112a316..c7f1ae864218 100644 --- a/iree/compiler/Dialect/Modules/TensorList/Conversion/test/BUILD +++ b/iree/compiler/Dialect/Modules/TensorList/Conversion/test/BUILD @@ -13,6 +13,7 @@ # limitations under the License. load("//iree:lit_test.bzl", "iree_lit_test_suite") +load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob") package( default_visibility = ["//visibility:public"], @@ -22,7 +23,13 @@ package( iree_lit_test_suite( name = "lit", - srcs = glob(["*.mlir"]), + srcs = enforce_glob( + [ + "convert_hal_to_vm.mlir", + "convert_to_hal.mlir", + ], + include = ["*.mlir"], + ), data = [ "//iree/tools:IreeFileCheck", "//iree/tools:iree-opt", diff --git a/iree/compiler/Dialect/Modules/TensorList/Conversion/test/CMakeLists.txt b/iree/compiler/Dialect/Modules/TensorList/Conversion/test/CMakeLists.txt index 42faab43624b..7b2121058027 100644 --- a/iree/compiler/Dialect/Modules/TensorList/Conversion/test/CMakeLists.txt +++ b/iree/compiler/Dialect/Modules/TensorList/Conversion/test/CMakeLists.txt @@ -10,12 +10,12 @@ iree_add_all_subdirs() -file(GLOB _GLOB_X_MLIR LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.mlir) iree_lit_test_suite( NAME lit SRCS - "${_GLOB_X_MLIR}" + "convert_hal_to_vm.mlir" + "convert_to_hal.mlir" DATA iree::tools::IreeFileCheck iree::tools::iree-opt diff --git a/iree/compiler/Dialect/Modules/TensorList/IR/BUILD b/iree/compiler/Dialect/Modules/TensorList/IR/BUILD index 02b533055129..ac7a903ad033 100644 --- a/iree/compiler/Dialect/Modules/TensorList/IR/BUILD +++ b/iree/compiler/Dialect/Modules/TensorList/IR/BUILD @@ -14,6 +14,7 @@ load("//build_tools/bazel:iree_tablegen_doc.bzl", "iree_tablegen_doc") load("//build_tools/bazel:tblgen.bzl", "gentbl") +load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob") package( default_visibility = ["//visibility:public"], @@ -25,7 +26,13 @@ exports_files(["TensorListBase.td"]) filegroup( name = "td_files", - srcs = glob(["*.td"]), + srcs = enforce_glob( + [ + "TensorListBase.td", + "TensorListOps.td", + ], + include = ["*.td"], + ), ) cc_library( diff --git a/iree/compiler/Dialect/Modules/TensorList/IR/CMakeLists.txt b/iree/compiler/Dialect/Modules/TensorList/IR/CMakeLists.txt index e50e0113ec43..b3f56e7c8658 100644 --- a/iree/compiler/Dialect/Modules/TensorList/IR/CMakeLists.txt +++ b/iree/compiler/Dialect/Modules/TensorList/IR/CMakeLists.txt @@ -10,7 +10,6 @@ iree_add_all_subdirs() -file(GLOB _GLOB_X_TD LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.td) iree_cc_library( NAME IR diff --git a/iree/compiler/Dialect/Modules/TensorList/IR/test/BUILD b/iree/compiler/Dialect/Modules/TensorList/IR/test/BUILD index b780b112a316..50dd4d0d66b4 100644 --- a/iree/compiler/Dialect/Modules/TensorList/IR/test/BUILD +++ b/iree/compiler/Dialect/Modules/TensorList/IR/test/BUILD @@ -13,6 +13,7 @@ # limitations under the License. load("//iree:lit_test.bzl", "iree_lit_test_suite") +load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob") package( default_visibility = ["//visibility:public"], @@ -22,7 +23,10 @@ package( iree_lit_test_suite( name = "lit", - srcs = glob(["*.mlir"]), + srcs = enforce_glob( + ["ops.mlir"], + include = ["*.mlir"], + ), data = [ "//iree/tools:IreeFileCheck", "//iree/tools:iree-opt", diff --git a/iree/compiler/Dialect/Modules/TensorList/IR/test/CMakeLists.txt b/iree/compiler/Dialect/Modules/TensorList/IR/test/CMakeLists.txt index c75d60e1c64a..f344660c0c92 100644 --- a/iree/compiler/Dialect/Modules/TensorList/IR/test/CMakeLists.txt +++ b/iree/compiler/Dialect/Modules/TensorList/IR/test/CMakeLists.txt @@ -10,12 +10,11 @@ iree_add_all_subdirs() -file(GLOB _GLOB_X_MLIR LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.mlir) iree_lit_test_suite( NAME lit SRCS - "${_GLOB_X_MLIR}" + "ops.mlir" DATA iree::tools::IreeFileCheck iree::tools::iree-opt diff --git a/iree/compiler/Dialect/Sequence/IR/BUILD b/iree/compiler/Dialect/Sequence/IR/BUILD index 430d094390d6..a670f1acc117 100644 --- a/iree/compiler/Dialect/Sequence/IR/BUILD +++ b/iree/compiler/Dialect/Sequence/IR/BUILD @@ -14,6 +14,7 @@ load("//build_tools/bazel:iree_tablegen_doc.bzl", "iree_tablegen_doc") load("//build_tools/bazel:tblgen.bzl", "gentbl") +load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob") package( default_visibility = ["//visibility:public"], @@ -23,7 +24,13 @@ package( filegroup( name = "td_files", - srcs = glob(["*.td"]), + srcs = enforce_glob( + [ + "SequenceBase.td", + "SequenceOps.td", + ], + include = ["*.td"], + ), ) cc_library( diff --git a/iree/compiler/Dialect/Sequence/IR/CMakeLists.txt b/iree/compiler/Dialect/Sequence/IR/CMakeLists.txt index cad8c6b252a0..2ce24bd7b4f7 100644 --- a/iree/compiler/Dialect/Sequence/IR/CMakeLists.txt +++ b/iree/compiler/Dialect/Sequence/IR/CMakeLists.txt @@ -10,7 +10,6 @@ iree_add_all_subdirs() -file(GLOB _GLOB_X_TD LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.td) iree_cc_library( NAME IR diff --git a/iree/compiler/Dialect/Sequence/IR/test/BUILD b/iree/compiler/Dialect/Sequence/IR/test/BUILD index bb9c0e0808d7..7eb3cff6c0e7 100644 --- a/iree/compiler/Dialect/Sequence/IR/test/BUILD +++ b/iree/compiler/Dialect/Sequence/IR/test/BUILD @@ -13,6 +13,7 @@ # limitations under the License. load("//iree:lit_test.bzl", "iree_lit_test_suite") +load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob") package( default_visibility = ["//visibility:public"], @@ -22,7 +23,13 @@ package( iree_lit_test_suite( name = "lit", - srcs = glob(["*.mlir"]), + srcs = enforce_glob( + [ + "sequence_map.mlir", + "sequence_of.mlir", + ], + include = ["*.mlir"], + ), data = [ "//iree/tools:IreeFileCheck", "//iree/tools:iree-opt", diff --git a/iree/compiler/Dialect/Sequence/IR/test/CMakeLists.txt b/iree/compiler/Dialect/Sequence/IR/test/CMakeLists.txt index 80344af8af68..e4132bbb2f57 100644 --- a/iree/compiler/Dialect/Sequence/IR/test/CMakeLists.txt +++ b/iree/compiler/Dialect/Sequence/IR/test/CMakeLists.txt @@ -10,12 +10,12 @@ iree_add_all_subdirs() -file(GLOB _GLOB_X_MLIR LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.mlir) iree_lit_test_suite( NAME lit SRCS - "${_GLOB_X_MLIR}" + "sequence_map.mlir" + "sequence_of.mlir" DATA iree::tools::IreeFileCheck iree::tools::iree-opt diff --git a/iree/compiler/Dialect/Shape/Conversion/test/BUILD b/iree/compiler/Dialect/Shape/Conversion/test/BUILD index bb9c0e0808d7..c861d29b2047 100644 --- a/iree/compiler/Dialect/Shape/Conversion/test/BUILD +++ b/iree/compiler/Dialect/Shape/Conversion/test/BUILD @@ -13,6 +13,7 @@ # limitations under the License. load("//iree:lit_test.bzl", "iree_lit_test_suite") +load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob") package( default_visibility = ["//visibility:public"], @@ -22,7 +23,10 @@ package( iree_lit_test_suite( name = "lit", - srcs = glob(["*.mlir"]), + srcs = enforce_glob( + ["shape_to_shapex.mlir"], + include = ["*.mlir"], + ), data = [ "//iree/tools:IreeFileCheck", "//iree/tools:iree-opt", diff --git a/iree/compiler/Dialect/Shape/Conversion/test/CMakeLists.txt b/iree/compiler/Dialect/Shape/Conversion/test/CMakeLists.txt index b945e8f56445..d22775fc9683 100644 --- a/iree/compiler/Dialect/Shape/Conversion/test/CMakeLists.txt +++ b/iree/compiler/Dialect/Shape/Conversion/test/CMakeLists.txt @@ -10,12 +10,11 @@ iree_add_all_subdirs() -file(GLOB _GLOB_X_MLIR LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.mlir) iree_lit_test_suite( NAME lit SRCS - "${_GLOB_X_MLIR}" + "shape_to_shapex.mlir" DATA iree::tools::IreeFileCheck iree::tools::iree-opt diff --git a/iree/compiler/Dialect/Shape/IR/BUILD b/iree/compiler/Dialect/Shape/IR/BUILD index 86c577cad944..2a0ff9592270 100644 --- a/iree/compiler/Dialect/Shape/IR/BUILD +++ b/iree/compiler/Dialect/Shape/IR/BUILD @@ -14,6 +14,7 @@ load("//build_tools/bazel:iree_tablegen_doc.bzl", "iree_tablegen_doc") load("//build_tools/bazel:tblgen.bzl", "gentbl") +load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob") package( default_visibility = ["//visibility:public"], @@ -25,7 +26,13 @@ exports_files(["ShapeBase.td"]) filegroup( name = "td_files", - srcs = glob(["*.td"]), + srcs = enforce_glob( + [ + "ShapeBase.td", + "ShapeOps.td", + ], + include = ["*.td"], + ), ) cc_library( diff --git a/iree/compiler/Dialect/Shape/IR/CMakeLists.txt b/iree/compiler/Dialect/Shape/IR/CMakeLists.txt index ca7179113282..b00454146a0f 100644 --- a/iree/compiler/Dialect/Shape/IR/CMakeLists.txt +++ b/iree/compiler/Dialect/Shape/IR/CMakeLists.txt @@ -10,7 +10,6 @@ iree_add_all_subdirs() -file(GLOB _GLOB_X_TD LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.td) iree_cc_library( NAME IR diff --git a/iree/compiler/Dialect/Shape/IR/test/BUILD b/iree/compiler/Dialect/Shape/IR/test/BUILD index b780b112a316..64deb152c406 100644 --- a/iree/compiler/Dialect/Shape/IR/test/BUILD +++ b/iree/compiler/Dialect/Shape/IR/test/BUILD @@ -13,6 +13,7 @@ # limitations under the License. load("//iree:lit_test.bzl", "iree_lit_test_suite") +load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob") package( default_visibility = ["//visibility:public"], @@ -22,7 +23,15 @@ package( iree_lit_test_suite( name = "lit", - srcs = glob(["*.mlir"]), + srcs = enforce_glob( + [ + "canonicalize.mlir", + "op_verification.mlir", + "parse_print.mlir", + "ranked_shape_type.mlir", + ], + include = ["*.mlir"], + ), data = [ "//iree/tools:IreeFileCheck", "//iree/tools:iree-opt", diff --git a/iree/compiler/Dialect/Shape/IR/test/CMakeLists.txt b/iree/compiler/Dialect/Shape/IR/test/CMakeLists.txt index 43c4a40d08a2..fa8db2b4d34f 100644 --- a/iree/compiler/Dialect/Shape/IR/test/CMakeLists.txt +++ b/iree/compiler/Dialect/Shape/IR/test/CMakeLists.txt @@ -10,12 +10,14 @@ iree_add_all_subdirs() -file(GLOB _GLOB_X_MLIR LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.mlir) iree_lit_test_suite( NAME lit SRCS - "${_GLOB_X_MLIR}" + "canonicalize.mlir" + "op_verification.mlir" + "parse_print.mlir" + "ranked_shape_type.mlir" DATA iree::tools::IreeFileCheck iree::tools::iree-opt diff --git a/iree/compiler/Dialect/Shape/Plugins/VMLA/test/BUILD b/iree/compiler/Dialect/Shape/Plugins/VMLA/test/BUILD index bb9c0e0808d7..24fd2874ad5e 100644 --- a/iree/compiler/Dialect/Shape/Plugins/VMLA/test/BUILD +++ b/iree/compiler/Dialect/Shape/Plugins/VMLA/test/BUILD @@ -13,6 +13,7 @@ # limitations under the License. load("//iree:lit_test.bzl", "iree_lit_test_suite") +load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob") package( default_visibility = ["//visibility:public"], @@ -22,7 +23,10 @@ package( iree_lit_test_suite( name = "lit", - srcs = glob(["*.mlir"]), + srcs = enforce_glob( + ["custom_ops.mlir"], + include = ["*.mlir"], + ), data = [ "//iree/tools:IreeFileCheck", "//iree/tools:iree-opt", diff --git a/iree/compiler/Dialect/Shape/Plugins/VMLA/test/CMakeLists.txt b/iree/compiler/Dialect/Shape/Plugins/VMLA/test/CMakeLists.txt index e0eb941b7a31..3c6404e90d64 100644 --- a/iree/compiler/Dialect/Shape/Plugins/VMLA/test/CMakeLists.txt +++ b/iree/compiler/Dialect/Shape/Plugins/VMLA/test/CMakeLists.txt @@ -10,12 +10,11 @@ iree_add_all_subdirs() -file(GLOB _GLOB_X_MLIR LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.mlir) iree_lit_test_suite( NAME lit SRCS - "${_GLOB_X_MLIR}" + "custom_ops.mlir" DATA iree::tools::IreeFileCheck iree::tools::iree-opt diff --git a/iree/compiler/Dialect/Shape/Plugins/XLA/test/BUILD b/iree/compiler/Dialect/Shape/Plugins/XLA/test/BUILD index bb9c0e0808d7..24fd2874ad5e 100644 --- a/iree/compiler/Dialect/Shape/Plugins/XLA/test/BUILD +++ b/iree/compiler/Dialect/Shape/Plugins/XLA/test/BUILD @@ -13,6 +13,7 @@ # limitations under the License. load("//iree:lit_test.bzl", "iree_lit_test_suite") +load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob") package( default_visibility = ["//visibility:public"], @@ -22,7 +23,10 @@ package( iree_lit_test_suite( name = "lit", - srcs = glob(["*.mlir"]), + srcs = enforce_glob( + ["custom_ops.mlir"], + include = ["*.mlir"], + ), data = [ "//iree/tools:IreeFileCheck", "//iree/tools:iree-opt", diff --git a/iree/compiler/Dialect/Shape/Plugins/XLA/test/CMakeLists.txt b/iree/compiler/Dialect/Shape/Plugins/XLA/test/CMakeLists.txt index 44744ecb6f31..2fd4f816907b 100644 --- a/iree/compiler/Dialect/Shape/Plugins/XLA/test/CMakeLists.txt +++ b/iree/compiler/Dialect/Shape/Plugins/XLA/test/CMakeLists.txt @@ -10,12 +10,11 @@ iree_add_all_subdirs() -file(GLOB _GLOB_X_MLIR LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.mlir) iree_lit_test_suite( NAME lit SRCS - "${_GLOB_X_MLIR}" + "custom_ops.mlir" DATA iree::tools::IreeFileCheck iree::tools::iree-opt diff --git a/iree/compiler/Dialect/Shape/Transforms/test/BUILD b/iree/compiler/Dialect/Shape/Transforms/test/BUILD index bb9c0e0808d7..eb8a1ecfe445 100644 --- a/iree/compiler/Dialect/Shape/Transforms/test/BUILD +++ b/iree/compiler/Dialect/Shape/Transforms/test/BUILD @@ -13,6 +13,7 @@ # limitations under the License. load("//iree:lit_test.bzl", "iree_lit_test_suite") +load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob") package( default_visibility = ["//visibility:public"], @@ -22,7 +23,18 @@ package( iree_lit_test_suite( name = "lit", - srcs = glob(["*.mlir"]), + srcs = enforce_glob( + [ + "convert_hlo_to_shape_dialect.mlir", + "expand_function_dynamic_dims.mlir", + "expand_function_ranked_shape_dims.mlir", + "hoist_shape_calculations.mlir", + "materialize_shape_calculations.mlir", + "tie_dynamic_shapes.mlir", + "tie_dynamic_shapes_no_recurse.mlir", + ], + include = ["*.mlir"], + ), data = [ "//iree/tools:IreeFileCheck", "//iree/tools:iree-opt", diff --git a/iree/compiler/Dialect/Shape/Transforms/test/CMakeLists.txt b/iree/compiler/Dialect/Shape/Transforms/test/CMakeLists.txt index 4acab9242282..6c0229e3efdf 100644 --- a/iree/compiler/Dialect/Shape/Transforms/test/CMakeLists.txt +++ b/iree/compiler/Dialect/Shape/Transforms/test/CMakeLists.txt @@ -10,12 +10,17 @@ iree_add_all_subdirs() -file(GLOB _GLOB_X_MLIR LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.mlir) iree_lit_test_suite( NAME lit SRCS - "${_GLOB_X_MLIR}" + "convert_hlo_to_shape_dialect.mlir" + "expand_function_dynamic_dims.mlir" + "expand_function_ranked_shape_dims.mlir" + "hoist_shape_calculations.mlir" + "materialize_shape_calculations.mlir" + "tie_dynamic_shapes.mlir" + "tie_dynamic_shapes_no_recurse.mlir" DATA iree::tools::IreeFileCheck iree::tools::iree-opt diff --git a/iree/compiler/Dialect/VM/Analysis/test/BUILD b/iree/compiler/Dialect/VM/Analysis/test/BUILD index b780b112a316..a411dd6ffd53 100644 --- a/iree/compiler/Dialect/VM/Analysis/test/BUILD +++ b/iree/compiler/Dialect/VM/Analysis/test/BUILD @@ -13,6 +13,7 @@ # limitations under the License. load("//iree:lit_test.bzl", "iree_lit_test_suite") +load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob") package( default_visibility = ["//visibility:public"], @@ -22,7 +23,13 @@ package( iree_lit_test_suite( name = "lit", - srcs = glob(["*.mlir"]), + srcs = enforce_glob( + [ + "register_allocation.mlir", + "value_liveness.mlir", + ], + include = ["*.mlir"], + ), data = [ "//iree/tools:IreeFileCheck", "//iree/tools:iree-opt", diff --git a/iree/compiler/Dialect/VM/Analysis/test/CMakeLists.txt b/iree/compiler/Dialect/VM/Analysis/test/CMakeLists.txt index cb58a80932b4..2dd28bfa3b50 100644 --- a/iree/compiler/Dialect/VM/Analysis/test/CMakeLists.txt +++ b/iree/compiler/Dialect/VM/Analysis/test/CMakeLists.txt @@ -10,12 +10,12 @@ iree_add_all_subdirs() -file(GLOB _GLOB_X_MLIR LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.mlir) iree_lit_test_suite( NAME lit SRCS - "${_GLOB_X_MLIR}" + "register_allocation.mlir" + "value_liveness.mlir" DATA iree::tools::IreeFileCheck iree::tools::iree-opt diff --git a/iree/compiler/Dialect/VM/Conversion/IREEToVM/test/BUILD b/iree/compiler/Dialect/VM/Conversion/IREEToVM/test/BUILD index bb9c0e0808d7..26a05ad4bb48 100644 --- a/iree/compiler/Dialect/VM/Conversion/IREEToVM/test/BUILD +++ b/iree/compiler/Dialect/VM/Conversion/IREEToVM/test/BUILD @@ -13,6 +13,7 @@ # limitations under the License. load("//iree:lit_test.bzl", "iree_lit_test_suite") +load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob") package( default_visibility = ["//visibility:public"], @@ -22,7 +23,13 @@ package( iree_lit_test_suite( name = "lit", - srcs = glob(["*.mlir"]), + srcs = enforce_glob( + [ + "byte_buffer_ops.mlir", + "hint_ops.mlir", + ], + include = ["*.mlir"], + ), data = [ "//iree/tools:IreeFileCheck", "//iree/tools:iree-opt", diff --git a/iree/compiler/Dialect/VM/Conversion/IREEToVM/test/CMakeLists.txt b/iree/compiler/Dialect/VM/Conversion/IREEToVM/test/CMakeLists.txt index c11e1256d462..81db6954bef8 100644 --- a/iree/compiler/Dialect/VM/Conversion/IREEToVM/test/CMakeLists.txt +++ b/iree/compiler/Dialect/VM/Conversion/IREEToVM/test/CMakeLists.txt @@ -10,12 +10,12 @@ iree_add_all_subdirs() -file(GLOB _GLOB_X_MLIR LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.mlir) iree_lit_test_suite( NAME lit SRCS - "${_GLOB_X_MLIR}" + "byte_buffer_ops.mlir" + "hint_ops.mlir" DATA iree::tools::IreeFileCheck iree::tools::iree-opt diff --git a/iree/compiler/Dialect/VM/Conversion/StandardToVM/test/BUILD b/iree/compiler/Dialect/VM/Conversion/StandardToVM/test/BUILD index b780b112a316..cd7c2f57edba 100644 --- a/iree/compiler/Dialect/VM/Conversion/StandardToVM/test/BUILD +++ b/iree/compiler/Dialect/VM/Conversion/StandardToVM/test/BUILD @@ -13,6 +13,7 @@ # limitations under the License. load("//iree:lit_test.bzl", "iree_lit_test_suite") +load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob") package( default_visibility = ["//visibility:public"], @@ -22,7 +23,18 @@ package( iree_lit_test_suite( name = "lit", - srcs = glob(["*.mlir"]), + srcs = enforce_glob( + [ + "arithmetic_ops.mlir", + "assignment_ops.mlir", + "comparison_ops.mlir", + "const_ops.mlir", + "control_flow_ops.mlir", + "func_attrs.mlir", + "structural_ops.mlir", + ], + include = ["*.mlir"], + ), data = [ "//iree/tools:IreeFileCheck", "//iree/tools:iree-opt", diff --git a/iree/compiler/Dialect/VM/Conversion/StandardToVM/test/CMakeLists.txt b/iree/compiler/Dialect/VM/Conversion/StandardToVM/test/CMakeLists.txt index 17e1e27ff562..38aadb785895 100644 --- a/iree/compiler/Dialect/VM/Conversion/StandardToVM/test/CMakeLists.txt +++ b/iree/compiler/Dialect/VM/Conversion/StandardToVM/test/CMakeLists.txt @@ -10,12 +10,17 @@ iree_add_all_subdirs() -file(GLOB _GLOB_X_MLIR LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.mlir) iree_lit_test_suite( NAME lit SRCS - "${_GLOB_X_MLIR}" + "arithmetic_ops.mlir" + "assignment_ops.mlir" + "comparison_ops.mlir" + "const_ops.mlir" + "control_flow_ops.mlir" + "func_attrs.mlir" + "structural_ops.mlir" DATA iree::tools::IreeFileCheck iree::tools::iree-opt diff --git a/iree/compiler/Dialect/VM/IR/BUILD b/iree/compiler/Dialect/VM/IR/BUILD index 5cc616239764..b07c0c5fb013 100644 --- a/iree/compiler/Dialect/VM/IR/BUILD +++ b/iree/compiler/Dialect/VM/IR/BUILD @@ -14,6 +14,7 @@ load("//build_tools/bazel:iree_tablegen_doc.bzl", "iree_tablegen_doc") load("//build_tools/bazel:tblgen.bzl", "gentbl") +load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob") package( default_visibility = ["//visibility:public"], @@ -25,7 +26,13 @@ exports_files(["VMOps.td"]) filegroup( name = "td_files", - srcs = glob(["*.td"]), + srcs = enforce_glob( + [ + "VMBase.td", + "VMOps.td", + ], + include = ["*.td"], + ), ) cc_library( diff --git a/iree/compiler/Dialect/VM/IR/CMakeLists.txt b/iree/compiler/Dialect/VM/IR/CMakeLists.txt index 0b1a1784ae6f..e9395ad58104 100644 --- a/iree/compiler/Dialect/VM/IR/CMakeLists.txt +++ b/iree/compiler/Dialect/VM/IR/CMakeLists.txt @@ -10,7 +10,6 @@ iree_add_all_subdirs() -file(GLOB _GLOB_X_TD LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.td) iree_cc_library( NAME IR diff --git a/iree/compiler/Dialect/VM/IR/test/BUILD b/iree/compiler/Dialect/VM/IR/test/BUILD index b780b112a316..76cf62ebf88f 100644 --- a/iree/compiler/Dialect/VM/IR/test/BUILD +++ b/iree/compiler/Dialect/VM/IR/test/BUILD @@ -13,6 +13,7 @@ # limitations under the License. load("//iree:lit_test.bzl", "iree_lit_test_suite") +load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob") package( default_visibility = ["//visibility:public"], @@ -22,7 +23,31 @@ package( iree_lit_test_suite( name = "lit", - srcs = glob(["*.mlir"]), + srcs = enforce_glob( + [ + "arithmetic_folding.mlir", + "arithmetic_ops.mlir", + "assignment_folding.mlir", + "assignment_ops.mlir", + "comparison_folding.mlir", + "comparison_ops.mlir", + "const_folding.mlir", + "const_ops.mlir", + "control_flow_folding.mlir", + "control_flow_ops.mlir", + "conversion_folding.mlir", + "conversion_ops.mlir", + "debug_folding.mlir", + "debug_ops.mlir", + "global_folding.mlir", + "global_ops.mlir", + "list_op_verification.mlir", + "list_ops.mlir", + "shift_ops.mlir", + "structural_ops.mlir", + ], + include = ["*.mlir"], + ), data = [ "//iree/tools:IreeFileCheck", "//iree/tools:iree-opt", diff --git a/iree/compiler/Dialect/VM/IR/test/CMakeLists.txt b/iree/compiler/Dialect/VM/IR/test/CMakeLists.txt index 66b2364af296..2604211f754e 100644 --- a/iree/compiler/Dialect/VM/IR/test/CMakeLists.txt +++ b/iree/compiler/Dialect/VM/IR/test/CMakeLists.txt @@ -10,12 +10,30 @@ iree_add_all_subdirs() -file(GLOB _GLOB_X_MLIR LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.mlir) iree_lit_test_suite( NAME lit SRCS - "${_GLOB_X_MLIR}" + "arithmetic_folding.mlir" + "arithmetic_ops.mlir" + "assignment_folding.mlir" + "assignment_ops.mlir" + "comparison_folding.mlir" + "comparison_ops.mlir" + "const_folding.mlir" + "const_ops.mlir" + "control_flow_folding.mlir" + "control_flow_ops.mlir" + "conversion_folding.mlir" + "conversion_ops.mlir" + "debug_folding.mlir" + "debug_ops.mlir" + "global_folding.mlir" + "global_ops.mlir" + "list_op_verification.mlir" + "list_ops.mlir" + "shift_ops.mlir" + "structural_ops.mlir" DATA iree::tools::IreeFileCheck iree::tools::iree-opt diff --git a/iree/compiler/Dialect/VM/Target/Bytecode/test/BUILD b/iree/compiler/Dialect/VM/Target/Bytecode/test/BUILD index 71d362403828..3e005972478b 100644 --- a/iree/compiler/Dialect/VM/Target/Bytecode/test/BUILD +++ b/iree/compiler/Dialect/VM/Target/Bytecode/test/BUILD @@ -13,6 +13,7 @@ # limitations under the License. load("//iree:lit_test.bzl", "iree_lit_test_suite") +load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob") package( default_visibility = ["//visibility:public"], @@ -22,7 +23,14 @@ package( iree_lit_test_suite( name = "lit", - srcs = glob(["*.mlir"]), + srcs = enforce_glob( + [ + "constant_encoding.mlir", + "module_encoding_smoke.mlir", + "reflection_attrs.mlir", + ], + include = ["*.mlir"], + ), data = [ "//iree/tools:IreeFileCheck", "//iree/tools:iree-translate", diff --git a/iree/compiler/Dialect/VM/Target/Bytecode/test/CMakeLists.txt b/iree/compiler/Dialect/VM/Target/Bytecode/test/CMakeLists.txt index 56dcc5835ca9..287b67d8bbfd 100644 --- a/iree/compiler/Dialect/VM/Target/Bytecode/test/CMakeLists.txt +++ b/iree/compiler/Dialect/VM/Target/Bytecode/test/CMakeLists.txt @@ -10,12 +10,13 @@ iree_add_all_subdirs() -file(GLOB _GLOB_X_MLIR LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.mlir) iree_lit_test_suite( NAME lit SRCS - "${_GLOB_X_MLIR}" + "constant_encoding.mlir" + "module_encoding_smoke.mlir" + "reflection_attrs.mlir" DATA iree::tools::IreeFileCheck iree::tools::iree-translate diff --git a/iree/compiler/Dialect/VM/Transforms/test/BUILD b/iree/compiler/Dialect/VM/Transforms/test/BUILD index b780b112a316..6ebc889abf13 100644 --- a/iree/compiler/Dialect/VM/Transforms/test/BUILD +++ b/iree/compiler/Dialect/VM/Transforms/test/BUILD @@ -13,6 +13,7 @@ # limitations under the License. load("//iree:lit_test.bzl", "iree_lit_test_suite") +load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob") package( default_visibility = ["//visibility:public"], @@ -22,7 +23,16 @@ package( iree_lit_test_suite( name = "lit", - srcs = glob(["*.mlir"]), + srcs = enforce_glob( + [ + "global_initialization.mlir", + "hoist_inlined_rodata.mlir", + "mark_public_symbols_exported.mlir", + "ordinal_allocation.mlir", + "sink_defining_ops.mlir", + ], + include = ["*.mlir"], + ), data = [ "//iree/tools:IreeFileCheck", "//iree/tools:iree-opt", diff --git a/iree/compiler/Dialect/VM/Transforms/test/CMakeLists.txt b/iree/compiler/Dialect/VM/Transforms/test/CMakeLists.txt index 8c85b404f1cf..dcae639e29bc 100644 --- a/iree/compiler/Dialect/VM/Transforms/test/CMakeLists.txt +++ b/iree/compiler/Dialect/VM/Transforms/test/CMakeLists.txt @@ -10,12 +10,15 @@ iree_add_all_subdirs() -file(GLOB _GLOB_X_MLIR LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.mlir) iree_lit_test_suite( NAME lit SRCS - "${_GLOB_X_MLIR}" + "global_initialization.mlir" + "hoist_inlined_rodata.mlir" + "mark_public_symbols_exported.mlir" + "ordinal_allocation.mlir" + "sink_defining_ops.mlir" DATA iree::tools::IreeFileCheck iree::tools::iree-opt diff --git a/iree/compiler/Dialect/VMLA/Conversion/HALToVMLA/test/BUILD b/iree/compiler/Dialect/VMLA/Conversion/HALToVMLA/test/BUILD index bb9c0e0808d7..40dc09d52b4e 100644 --- a/iree/compiler/Dialect/VMLA/Conversion/HALToVMLA/test/BUILD +++ b/iree/compiler/Dialect/VMLA/Conversion/HALToVMLA/test/BUILD @@ -13,6 +13,7 @@ # limitations under the License. load("//iree:lit_test.bzl", "iree_lit_test_suite") +load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob") package( default_visibility = ["//visibility:public"], @@ -22,7 +23,10 @@ package( iree_lit_test_suite( name = "lit", - srcs = glob(["*.mlir"]), + srcs = enforce_glob( + ["interface_ops.mlir"], + include = ["*.mlir"], + ), data = [ "//iree/tools:IreeFileCheck", "//iree/tools:iree-opt", diff --git a/iree/compiler/Dialect/VMLA/Conversion/HALToVMLA/test/CMakeLists.txt b/iree/compiler/Dialect/VMLA/Conversion/HALToVMLA/test/CMakeLists.txt index 096adae6ff75..42cc9b2bc342 100644 --- a/iree/compiler/Dialect/VMLA/Conversion/HALToVMLA/test/CMakeLists.txt +++ b/iree/compiler/Dialect/VMLA/Conversion/HALToVMLA/test/CMakeLists.txt @@ -10,12 +10,11 @@ iree_add_all_subdirs() -file(GLOB _GLOB_X_MLIR LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.mlir) iree_lit_test_suite( NAME lit SRCS - "${_GLOB_X_MLIR}" + "interface_ops.mlir" DATA iree::tools::IreeFileCheck iree::tools::iree-opt diff --git a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/BUILD b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/BUILD index bb9c0e0808d7..3cbab8b24a2d 100644 --- a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/BUILD +++ b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/BUILD @@ -13,6 +13,7 @@ # limitations under the License. load("//iree:lit_test.bzl", "iree_lit_test_suite") +load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob") package( default_visibility = ["//visibility:public"], @@ -22,7 +23,25 @@ package( iree_lit_test_suite( name = "lit", - srcs = glob(["*.mlir"]), + srcs = enforce_glob( + [ + "broadcast_in_dim.mlir", + "concatenate.mlir", + "conv.mlir", + "convert.mlir", + "dynamic_slice.mlir", + "fft.mlir", + "math_ops.mlir", + "reduce.mlir", + "reduce_window.mlir", + "reshape.mlir", + "scatter.mlir", + "slice.mlir", + "sort.mlir", + "transpose.mlir", + ], + include = ["*.mlir"], + ), data = [ "//iree/tools:IreeFileCheck", "//iree/tools:iree-opt", diff --git a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/CMakeLists.txt b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/CMakeLists.txt index 8e0086f90cf6..fbd6949426ce 100644 --- a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/CMakeLists.txt +++ b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/CMakeLists.txt @@ -10,12 +10,24 @@ iree_add_all_subdirs() -file(GLOB _GLOB_X_MLIR LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.mlir) iree_lit_test_suite( NAME lit SRCS - "${_GLOB_X_MLIR}" + "broadcast_in_dim.mlir" + "concatenate.mlir" + "conv.mlir" + "convert.mlir" + "dynamic_slice.mlir" + "fft.mlir" + "math_ops.mlir" + "reduce.mlir" + "reduce_window.mlir" + "reshape.mlir" + "scatter.mlir" + "slice.mlir" + "sort.mlir" + "transpose.mlir" DATA iree::tools::IreeFileCheck iree::tools::iree-opt diff --git a/iree/compiler/Dialect/VMLA/Conversion/StandardToVMLA/test/BUILD b/iree/compiler/Dialect/VMLA/Conversion/StandardToVMLA/test/BUILD index bb9c0e0808d7..80a7bd483590 100644 --- a/iree/compiler/Dialect/VMLA/Conversion/StandardToVMLA/test/BUILD +++ b/iree/compiler/Dialect/VMLA/Conversion/StandardToVMLA/test/BUILD @@ -13,6 +13,7 @@ # limitations under the License. load("//iree:lit_test.bzl", "iree_lit_test_suite") +load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob") package( default_visibility = ["//visibility:public"], @@ -22,7 +23,14 @@ package( iree_lit_test_suite( name = "lit", - srcs = glob(["*.mlir"]), + srcs = enforce_glob( + [ + "comparison_ops.mlir", + "constant_ops.mlir", + "math_ops.mlir", + ], + include = ["*.mlir"], + ), data = [ "//iree/tools:IreeFileCheck", "//iree/tools:iree-opt", diff --git a/iree/compiler/Dialect/VMLA/Conversion/StandardToVMLA/test/CMakeLists.txt b/iree/compiler/Dialect/VMLA/Conversion/StandardToVMLA/test/CMakeLists.txt index 4b5dad785f34..ed4a341301fe 100644 --- a/iree/compiler/Dialect/VMLA/Conversion/StandardToVMLA/test/CMakeLists.txt +++ b/iree/compiler/Dialect/VMLA/Conversion/StandardToVMLA/test/CMakeLists.txt @@ -10,12 +10,13 @@ iree_add_all_subdirs() -file(GLOB _GLOB_X_MLIR LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.mlir) iree_lit_test_suite( NAME lit SRCS - "${_GLOB_X_MLIR}" + "comparison_ops.mlir" + "constant_ops.mlir" + "math_ops.mlir" DATA iree::tools::IreeFileCheck iree::tools::iree-opt diff --git a/iree/compiler/Dialect/VMLA/Conversion/VMLAToVM/test/BUILD b/iree/compiler/Dialect/VMLA/Conversion/VMLAToVM/test/BUILD index bb9c0e0808d7..368d1027565b 100644 --- a/iree/compiler/Dialect/VMLA/Conversion/VMLAToVM/test/BUILD +++ b/iree/compiler/Dialect/VMLA/Conversion/VMLAToVM/test/BUILD @@ -13,6 +13,7 @@ # limitations under the License. load("//iree:lit_test.bzl", "iree_lit_test_suite") +load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob") package( default_visibility = ["//visibility:public"], @@ -22,7 +23,13 @@ package( iree_lit_test_suite( name = "lit", - srcs = glob(["*.mlir"]), + srcs = enforce_glob( + [ + "constant_ops.mlir", + "conversion.mlir", + ], + include = ["*.mlir"], + ), data = [ "//iree/tools:IreeFileCheck", "//iree/tools:iree-opt", diff --git a/iree/compiler/Dialect/VMLA/Conversion/VMLAToVM/test/CMakeLists.txt b/iree/compiler/Dialect/VMLA/Conversion/VMLAToVM/test/CMakeLists.txt index 06fa3fd0ae00..e60fbfeca624 100644 --- a/iree/compiler/Dialect/VMLA/Conversion/VMLAToVM/test/CMakeLists.txt +++ b/iree/compiler/Dialect/VMLA/Conversion/VMLAToVM/test/CMakeLists.txt @@ -10,12 +10,12 @@ iree_add_all_subdirs() -file(GLOB _GLOB_X_MLIR LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.mlir) iree_lit_test_suite( NAME lit SRCS - "${_GLOB_X_MLIR}" + "constant_ops.mlir" + "conversion.mlir" DATA iree::tools::IreeFileCheck iree::tools::iree-opt diff --git a/iree/compiler/Dialect/VMLA/IR/BUILD b/iree/compiler/Dialect/VMLA/IR/BUILD index 40e8973cc3e0..365d43e150e4 100644 --- a/iree/compiler/Dialect/VMLA/IR/BUILD +++ b/iree/compiler/Dialect/VMLA/IR/BUILD @@ -14,6 +14,7 @@ load("//build_tools/bazel:iree_tablegen_doc.bzl", "iree_tablegen_doc") load("//build_tools/bazel:tblgen.bzl", "gentbl") +load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob") package( default_visibility = ["//visibility:public"], @@ -25,7 +26,13 @@ exports_files(["VMLAOps.td"]) filegroup( name = "td_files", - srcs = glob(["*.td"]), + srcs = enforce_glob( + [ + "VMLABase.td", + "VMLAOps.td", + ], + include = ["*.td"], + ), ) cc_library( diff --git a/iree/compiler/Dialect/VMLA/IR/CMakeLists.txt b/iree/compiler/Dialect/VMLA/IR/CMakeLists.txt index c35d99cf7576..c33ab9f72501 100644 --- a/iree/compiler/Dialect/VMLA/IR/CMakeLists.txt +++ b/iree/compiler/Dialect/VMLA/IR/CMakeLists.txt @@ -10,7 +10,6 @@ iree_add_all_subdirs() -file(GLOB _GLOB_X_TD LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.td) iree_cc_library( NAME IR diff --git a/iree/compiler/Dialect/VMLA/IR/test/BUILD b/iree/compiler/Dialect/VMLA/IR/test/BUILD index bb9c0e0808d7..228bc609241e 100644 --- a/iree/compiler/Dialect/VMLA/IR/test/BUILD +++ b/iree/compiler/Dialect/VMLA/IR/test/BUILD @@ -13,6 +13,7 @@ # limitations under the License. load("//iree:lit_test.bzl", "iree_lit_test_suite") +load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob") package( default_visibility = ["//visibility:public"], @@ -22,7 +23,15 @@ package( iree_lit_test_suite( name = "lit", - srcs = glob(["*.mlir"]), + srcs = enforce_glob( + [ + "buffer_ops.mlir", + "conv_reduction_ops.mlir", + "general_ops.mlir", + "shape_structure_ops.mlir", + ], + include = ["*.mlir"], + ), data = [ "//iree/tools:IreeFileCheck", "//iree/tools:iree-opt", diff --git a/iree/compiler/Dialect/VMLA/IR/test/CMakeLists.txt b/iree/compiler/Dialect/VMLA/IR/test/CMakeLists.txt index c4a637329d89..a392ee0f426c 100644 --- a/iree/compiler/Dialect/VMLA/IR/test/CMakeLists.txt +++ b/iree/compiler/Dialect/VMLA/IR/test/CMakeLists.txt @@ -10,12 +10,14 @@ iree_add_all_subdirs() -file(GLOB _GLOB_X_MLIR LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.mlir) iree_lit_test_suite( NAME lit SRCS - "${_GLOB_X_MLIR}" + "buffer_ops.mlir" + "conv_reduction_ops.mlir" + "general_ops.mlir" + "shape_structure_ops.mlir" DATA iree::tools::IreeFileCheck iree::tools::iree-opt diff --git a/iree/compiler/Dialect/VMLA/Transforms/test/BUILD b/iree/compiler/Dialect/VMLA/Transforms/test/BUILD index bb9c0e0808d7..8e124ba2e665 100644 --- a/iree/compiler/Dialect/VMLA/Transforms/test/BUILD +++ b/iree/compiler/Dialect/VMLA/Transforms/test/BUILD @@ -13,6 +13,7 @@ # limitations under the License. load("//iree:lit_test.bzl", "iree_lit_test_suite") +load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob") package( default_visibility = ["//visibility:public"], @@ -22,7 +23,14 @@ package( iree_lit_test_suite( name = "lit", - srcs = glob(["*.mlir"]), + srcs = enforce_glob( + [ + "pre_conversion_lowering.mlir", + "transformation.mlir", + "unroll_reductions.mlir", + ], + include = ["*.mlir"], + ), data = [ "//iree/tools:IreeFileCheck", "//iree/tools:iree-opt", diff --git a/iree/compiler/Dialect/VMLA/Transforms/test/CMakeLists.txt b/iree/compiler/Dialect/VMLA/Transforms/test/CMakeLists.txt index 94ba2219ef08..4a84f8595d0c 100644 --- a/iree/compiler/Dialect/VMLA/Transforms/test/CMakeLists.txt +++ b/iree/compiler/Dialect/VMLA/Transforms/test/CMakeLists.txt @@ -10,12 +10,13 @@ iree_add_all_subdirs() -file(GLOB _GLOB_X_MLIR LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.mlir) iree_lit_test_suite( NAME lit SRCS - "${_GLOB_X_MLIR}" + "pre_conversion_lowering.mlir" + "transformation.mlir" + "unroll_reductions.mlir" DATA iree::tools::IreeFileCheck iree::tools::iree-opt diff --git a/iree/compiler/Dialect/Vulkan/IR/BUILD b/iree/compiler/Dialect/Vulkan/IR/BUILD index 260521d56de0..9b9220b008be 100644 --- a/iree/compiler/Dialect/Vulkan/IR/BUILD +++ b/iree/compiler/Dialect/Vulkan/IR/BUILD @@ -13,6 +13,7 @@ # limitations under the License. load("//build_tools/bazel:tblgen.bzl", "gentbl") +load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob") package( default_visibility = ["//visibility:public"], @@ -22,7 +23,13 @@ package( filegroup( name = "td_files", - srcs = glob(["*.td"]), + srcs = enforce_glob( + [ + "VulkanAttributes.td", + "VulkanBase.td", + ], + include = ["*.td"], + ), ) cc_library( diff --git a/iree/compiler/Dialect/Vulkan/IR/CMakeLists.txt b/iree/compiler/Dialect/Vulkan/IR/CMakeLists.txt index 06bdf83818f6..0a03e5f98852 100644 --- a/iree/compiler/Dialect/Vulkan/IR/CMakeLists.txt +++ b/iree/compiler/Dialect/Vulkan/IR/CMakeLists.txt @@ -10,7 +10,6 @@ iree_add_all_subdirs() -file(GLOB _GLOB_X_TD LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.td) iree_cc_library( NAME IR diff --git a/iree/compiler/Dialect/Vulkan/IR/test/BUILD b/iree/compiler/Dialect/Vulkan/IR/test/BUILD index bb9c0e0808d7..37d4ac5967c0 100644 --- a/iree/compiler/Dialect/Vulkan/IR/test/BUILD +++ b/iree/compiler/Dialect/Vulkan/IR/test/BUILD @@ -13,6 +13,7 @@ # limitations under the License. load("//iree:lit_test.bzl", "iree_lit_test_suite") +load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob") package( default_visibility = ["//visibility:public"], @@ -22,7 +23,10 @@ package( iree_lit_test_suite( name = "lit", - srcs = glob(["*.mlir"]), + srcs = enforce_glob( + ["target_env.mlir"], + include = ["*.mlir"], + ), data = [ "//iree/tools:IreeFileCheck", "//iree/tools:iree-opt", diff --git a/iree/compiler/Dialect/Vulkan/IR/test/CMakeLists.txt b/iree/compiler/Dialect/Vulkan/IR/test/CMakeLists.txt index ddec28513220..3d87e77b7e82 100644 --- a/iree/compiler/Dialect/Vulkan/IR/test/CMakeLists.txt +++ b/iree/compiler/Dialect/Vulkan/IR/test/CMakeLists.txt @@ -10,12 +10,11 @@ iree_add_all_subdirs() -file(GLOB _GLOB_X_MLIR LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.mlir) iree_lit_test_suite( NAME lit SRCS - "${_GLOB_X_MLIR}" + "target_env.mlir" DATA iree::tools::IreeFileCheck iree::tools::iree-opt diff --git a/iree/compiler/Dialect/Vulkan/Utils/test/BUILD b/iree/compiler/Dialect/Vulkan/Utils/test/BUILD index 29aa3a5b0bf8..0b3152403c8f 100644 --- a/iree/compiler/Dialect/Vulkan/Utils/test/BUILD +++ b/iree/compiler/Dialect/Vulkan/Utils/test/BUILD @@ -14,6 +14,7 @@ load("//iree:build_defs.oss.bzl", "iree_cmake_extra_content") load("//iree:lit_test.bzl", "iree_lit_test_suite") +load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob") package( default_visibility = ["//visibility:public"], @@ -31,7 +32,10 @@ endif() iree_lit_test_suite( name = "lit", - srcs = glob(["*.mlir"]), + srcs = enforce_glob( + ["target_env_conversion.mlir"], + include = ["*.mlir"], + ), data = [ "//iree/tools:IreeFileCheck", "//iree/tools:iree-opt", diff --git a/iree/compiler/Dialect/Vulkan/Utils/test/CMakeLists.txt b/iree/compiler/Dialect/Vulkan/Utils/test/CMakeLists.txt index 2dd596b2b2cb..52386a414e2d 100644 --- a/iree/compiler/Dialect/Vulkan/Utils/test/CMakeLists.txt +++ b/iree/compiler/Dialect/Vulkan/Utils/test/CMakeLists.txt @@ -14,12 +14,11 @@ endif() iree_add_all_subdirs() -file(GLOB _GLOB_X_MLIR LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.mlir) iree_lit_test_suite( NAME lit SRCS - "${_GLOB_X_MLIR}" + "target_env_conversion.mlir" DATA iree::tools::IreeFileCheck iree::tools::iree-opt diff --git a/iree/compiler/Translation/test/BUILD b/iree/compiler/Translation/test/BUILD index cac3e25116fd..ed1d5ed68d44 100644 --- a/iree/compiler/Translation/test/BUILD +++ b/iree/compiler/Translation/test/BUILD @@ -15,6 +15,7 @@ # Tests for common transforms. load("//iree:lit_test.bzl", "iree_lit_test_suite") +load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob") package( default_visibility = ["//visibility:public"], @@ -24,7 +25,13 @@ package( iree_lit_test_suite( name = "lit", - srcs = glob(["*.mlir"]), + srcs = enforce_glob( + [ + "do_not_optimize.mlir", + "smoketest.mlir", + ], + include = ["*.mlir"], + ), data = [ "//iree/tools:IreeFileCheck", "//iree/tools:iree-opt", diff --git a/iree/compiler/Translation/test/CMakeLists.txt b/iree/compiler/Translation/test/CMakeLists.txt index 1e21a2f0f37d..ef4fb17852e9 100644 --- a/iree/compiler/Translation/test/CMakeLists.txt +++ b/iree/compiler/Translation/test/CMakeLists.txt @@ -10,12 +10,12 @@ iree_add_all_subdirs() -file(GLOB _GLOB_X_MLIR LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.mlir) iree_lit_test_suite( NAME lit SRCS - "${_GLOB_X_MLIR}" + "do_not_optimize.mlir" + "smoketest.mlir" DATA iree::tools::IreeFileCheck iree::tools::iree-opt diff --git a/iree/modules/check/test/BUILD b/iree/modules/check/test/BUILD index 162109cf2862..a9205903ab4b 100644 --- a/iree/modules/check/test/BUILD +++ b/iree/modules/check/test/BUILD @@ -14,6 +14,7 @@ load("//build_tools/bazel:iree_check_test.bzl", "iree_check_test_suite") load("//iree:lit_test.bzl", "iree_lit_test_suite") +load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob") package( default_visibility = ["//visibility:public"], @@ -23,7 +24,13 @@ package( iree_lit_test_suite( name = "lit", - srcs = glob(["*.mlir"]), + srcs = enforce_glob( + [ + "failure.mlir", + "success.mlir", + ], + include = ["*.mlir"], + ), data = [ "//iree/tools:IreeFileCheck", "//iree/tools:iree-check-module", diff --git a/iree/modules/check/test/CMakeLists.txt b/iree/modules/check/test/CMakeLists.txt index 48a8afe6cde8..943942155d25 100644 --- a/iree/modules/check/test/CMakeLists.txt +++ b/iree/modules/check/test/CMakeLists.txt @@ -10,12 +10,12 @@ iree_add_all_subdirs() -file(GLOB _GLOB_X_MLIR LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.mlir) iree_lit_test_suite( NAME lit SRCS - "${_GLOB_X_MLIR}" + "failure.mlir" + "success.mlir" DATA iree::tools::IreeFileCheck iree::tools::iree-check-module diff --git a/iree/samples/custom_modules/dialect/BUILD b/iree/samples/custom_modules/dialect/BUILD index fa0bc78cb04c..19eac520704c 100644 --- a/iree/samples/custom_modules/dialect/BUILD +++ b/iree/samples/custom_modules/dialect/BUILD @@ -14,6 +14,7 @@ load("//build_tools/embed_data:build_defs.bzl", "cc_embed_data") load("//build_tools/bazel:tblgen.bzl", "gentbl") +load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob") package( default_visibility = ["//visibility:public"], @@ -23,7 +24,10 @@ package( filegroup( name = "td_files", - srcs = glob(["*.td"]), + srcs = enforce_glob( + ["custom_ops.td"], + include = ["*.td"], + ), ) cc_library( diff --git a/iree/samples/custom_modules/dialect/CMakeLists.txt b/iree/samples/custom_modules/dialect/CMakeLists.txt index 01ea23d55de1..758d9fa3a1d5 100644 --- a/iree/samples/custom_modules/dialect/CMakeLists.txt +++ b/iree/samples/custom_modules/dialect/CMakeLists.txt @@ -10,7 +10,6 @@ iree_add_all_subdirs() -file(GLOB _GLOB_X_TD LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.td) iree_cc_library( NAME dialect diff --git a/iree/samples/custom_modules/dialect/test/BUILD b/iree/samples/custom_modules/dialect/test/BUILD index d0637066cfe4..98133e326e9a 100644 --- a/iree/samples/custom_modules/dialect/test/BUILD +++ b/iree/samples/custom_modules/dialect/test/BUILD @@ -13,6 +13,7 @@ # limitations under the License. load("//iree:lit_test.bzl", "iree_lit_test_suite") +load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob") package( default_visibility = ["//visibility:public"], @@ -22,18 +23,15 @@ package( iree_lit_test_suite( name = "lit", - srcs = glob(["*.mlir"]), + srcs = enforce_glob( + [ + "conversion.mlir", + "custom_ops.mlir", + ], + include = ["*.mlir"], + ), data = [ "//iree/samples/custom_modules/dialect:custom-opt", "//iree/tools:IreeFileCheck", ], ) -# load("//iree:lit_test.bzl", "iree_lit_test_suite") -# iree_lit_test_suite( -# name = "lit", -# srcs = glob(["*.mlir"]), -# data = [ -# "//iree/tools:IreeFileCheck", -# "//iree/samples/custom_modules/dialect:custom-opt", -# ], -# ) diff --git a/iree/samples/custom_modules/dialect/test/CMakeLists.txt b/iree/samples/custom_modules/dialect/test/CMakeLists.txt index f3b9b6c25b06..e8aaaa50cbfe 100644 --- a/iree/samples/custom_modules/dialect/test/CMakeLists.txt +++ b/iree/samples/custom_modules/dialect/test/CMakeLists.txt @@ -10,12 +10,12 @@ iree_add_all_subdirs() -file(GLOB _GLOB_X_MLIR LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.mlir) iree_lit_test_suite( NAME lit SRCS - "${_GLOB_X_MLIR}" + "conversion.mlir" + "custom_ops.mlir" DATA iree::samples::custom_modules::dialect::custom-opt iree::tools::IreeFileCheck diff --git a/iree/test/e2e/hackability/BUILD b/iree/test/e2e/hackability/BUILD index 361d3dfd8f8e..ffaf264c6c51 100644 --- a/iree/test/e2e/hackability/BUILD +++ b/iree/test/e2e/hackability/BUILD @@ -17,6 +17,7 @@ # those tests in iree/tools/test/ load("//iree:lit_test.bzl", "iree_lit_test_suite") +load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob") package( default_visibility = ["//visibility:public"], @@ -26,7 +27,10 @@ package( iree_lit_test_suite( name = "lit", - srcs = glob(["*.mlir"]), + srcs = enforce_glob( + ["flow_partitioned.mlir"], + include = ["*.mlir"], + ), data = [ "//iree/tools:IreeFileCheck", "//iree/tools:iree-run-mlir", diff --git a/iree/test/e2e/hackability/CMakeLists.txt b/iree/test/e2e/hackability/CMakeLists.txt index e03613370fa8..928042ac1fbb 100644 --- a/iree/test/e2e/hackability/CMakeLists.txt +++ b/iree/test/e2e/hackability/CMakeLists.txt @@ -10,12 +10,11 @@ iree_add_all_subdirs() -file(GLOB _GLOB_X_MLIR LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.mlir) iree_lit_test_suite( NAME lit SRCS - "${_GLOB_X_MLIR}" + "flow_partitioned.mlir" DATA iree::tools::IreeFileCheck iree::tools::iree-run-mlir diff --git a/iree/test/e2e/linalg_tensor_ops/BUILD b/iree/test/e2e/linalg_tensor_ops/BUILD index c580c9785072..df5113f3b8d2 100644 --- a/iree/test/e2e/linalg_tensor_ops/BUILD +++ b/iree/test/e2e/linalg_tensor_ops/BUILD @@ -18,6 +18,7 @@ # written using the IREE Check framework and should always pass on the reference VMLA backend. # See https://google.github.io/iree/TestingGuide#iree-core-end-to-end-tests. +load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob") load("//build_tools/bazel:iree_check_test.bzl", "iree_check_single_backend_test_suite") package( @@ -26,9 +27,17 @@ package( licenses = ["notice"], # Apache 2.0 ) +ALL_SRCS = enforce_glob( + [ + "add.mlir", + "matmul.mlir", + ], + include = ["*.mlir"], +) + iree_check_single_backend_test_suite( name = "check_llvm-ir_llvm", - srcs = glob(["*.mlir"]), + srcs = ALL_SRCS, compiler_flags = [ "-iree-flow-dispatch-linalg-on-tensors", "-iree-codegen-llvm-experimental-linalg-on-tensors", @@ -39,7 +48,7 @@ iree_check_single_backend_test_suite( iree_check_single_backend_test_suite( name = "check_vulkan-spirv_vulkan", - srcs = glob(["*.mlir"]), + srcs = ALL_SRCS, compiler_flags = [ "-iree-flow-dispatch-linalg-on-tensors", "-iree-codegen-spirv-experimental-linalg-on-tensors", diff --git a/iree/test/e2e/linalg_tensor_ops/CMakeLists.txt b/iree/test/e2e/linalg_tensor_ops/CMakeLists.txt index 642e8e2aa45b..f60e17c17766 100644 --- a/iree/test/e2e/linalg_tensor_ops/CMakeLists.txt +++ b/iree/test/e2e/linalg_tensor_ops/CMakeLists.txt @@ -10,12 +10,12 @@ iree_add_all_subdirs() -file(GLOB _GLOB_X_MLIR LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.mlir) iree_check_single_backend_test_suite( NAME check_llvm-ir_llvm SRCS - "${_GLOB_X_MLIR}" + "add.mlir" + "matmul.mlir" TARGET_BACKEND "dylib-llvm-aot" DRIVER @@ -25,12 +25,12 @@ iree_check_single_backend_test_suite( "-iree-codegen-llvm-experimental-linalg-on-tensors" ) -file(GLOB _GLOB_X_MLIR LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.mlir) iree_check_single_backend_test_suite( NAME check_vulkan-spirv_vulkan SRCS - "${_GLOB_X_MLIR}" + "add.mlir" + "matmul.mlir" TARGET_BACKEND "vulkan-spirv" DRIVER diff --git a/iree/test/e2e/models/BUILD b/iree/test/e2e/models/BUILD index 1acc650f87d6..0d3b5fc0c9cd 100644 --- a/iree/test/e2e/models/BUILD +++ b/iree/test/e2e/models/BUILD @@ -14,6 +14,7 @@ # Tests for end-to-end IREE support of entire models or their close derivatives. +load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob") load("//iree:lit_test.bzl", "iree_lit_test_suite") load("//build_tools/bazel:iree_check_test.bzl", "iree_check_single_backend_test_suite") @@ -28,8 +29,18 @@ CHECK_FRAMEWORK_TESTS = ["bert_encoder_unrolled_fake_weights.mlir"] iree_lit_test_suite( name = "lit", size = "medium", - srcs = glob( - ["*.mlir"], + srcs = enforce_glob( + [ + "collatz.mlir", + "edge_detection.mlir", + "fragment_000.mlir", + "fullyconnected.mlir", + "mnist_fake_weights.mlir", + "resnet_fake_weights.mlir", + "unidirectional_lstm.mlir", + ], + include = + ["*.mlir"], exclude = CHECK_FRAMEWORK_TESTS, ), data = [ diff --git a/iree/test/e2e/models/CMakeLists.txt b/iree/test/e2e/models/CMakeLists.txt index 03ca925612c5..4fd5d28b364f 100644 --- a/iree/test/e2e/models/CMakeLists.txt +++ b/iree/test/e2e/models/CMakeLists.txt @@ -10,14 +10,17 @@ iree_add_all_subdirs() -file(GLOB _GLOB_X_MLIR LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.mlir) -file(GLOB _GLOB_BERT_ENCODER_UNROLLED_FAKE_WEIGHTS_MLIR LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS bert_encoder_unrolled_fake_weights.mlir) -list(REMOVE_ITEM _GLOB_X_MLIR ${_GLOB_BERT_ENCODER_UNROLLED_FAKE_WEIGHTS_MLIR}) iree_lit_test_suite( NAME lit SRCS - "${_GLOB_X_MLIR}" + "collatz.mlir" + "edge_detection.mlir" + "fragment_000.mlir" + "fullyconnected.mlir" + "mnist_fake_weights.mlir" + "resnet_fake_weights.mlir" + "unidirectional_lstm.mlir" DATA iree::tools::IreeFileCheck iree::tools::iree-run-mlir diff --git a/iree/test/e2e/regression/BUILD b/iree/test/e2e/regression/BUILD index 8e23925ae6fd..74f00d8755eb 100644 --- a/iree/test/e2e/regression/BUILD +++ b/iree/test/e2e/regression/BUILD @@ -16,6 +16,7 @@ # These should focus on support by IREE itself, not for issues with specific runner tools. Place # those tests in iree/tools/test/ +load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob") load("//iree:lit_test.bzl", "iree_lit_test_suite") package( @@ -26,8 +27,25 @@ package( iree_lit_test_suite( name = "lit", - srcs = glob( - ["*.mlir"], + srcs = enforce_glob( + [ + "dynamic_abs.mlir", + "dynamic_add.mlir", + "dynamic_compare_and_select.mlir", + "dynamic_dot.mlir", + "dynamic_dot_general.mlir", + "dynamic_torch_index_select_high_rank.mlir", + "dynamic_torch_index_select_negative.mlir", + "dynamic_torch_index_select_scalar.mlir", + "dynamic_torch_index_select_vector.mlir", + "executable_benchmark.mlir", + "globals.mlir", + "scalar.mlir", + "trace_dispatch_tensors.mlir", + "unused_args.mlir", + ], + include = + ["*.mlir"], # Disabled temporarily. See GH Issue #4733 exclude = [ "dynamic_linalg_matmul_on_tensors.mlir", diff --git a/iree/test/e2e/regression/CMakeLists.txt b/iree/test/e2e/regression/CMakeLists.txt index dbc8f464baf7..bd998946879f 100644 --- a/iree/test/e2e/regression/CMakeLists.txt +++ b/iree/test/e2e/regression/CMakeLists.txt @@ -10,20 +10,24 @@ iree_add_all_subdirs() -file(GLOB _GLOB_X_MLIR LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.mlir) -file(GLOB _GLOB_DYNAMIC_LINALG_MATMUL_ON_TENSORS_MLIR LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS dynamic_linalg_matmul_on_tensors.mlir) -list(REMOVE_ITEM _GLOB_X_MLIR ${_GLOB_DYNAMIC_LINALG_MATMUL_ON_TENSORS_MLIR}) -file(GLOB _GLOB_DYNAMIC_LINALG_MATMUL_ON_TENSORS_FUSE_0_MLIR LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS dynamic_linalg_matmul_on_tensors_fuse_0.mlir) -list(REMOVE_ITEM _GLOB_X_MLIR ${_GLOB_DYNAMIC_LINALG_MATMUL_ON_TENSORS_FUSE_0_MLIR}) -file(GLOB _GLOB_DYNAMIC_LINALG_MATMUL_ON_TENSORS_FUSE_1_MLIR LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS dynamic_linalg_matmul_on_tensors_fuse_1.mlir) -list(REMOVE_ITEM _GLOB_X_MLIR ${_GLOB_DYNAMIC_LINALG_MATMUL_ON_TENSORS_FUSE_1_MLIR}) -file(GLOB _GLOB_DYNAMIC_LINALG_MATMUL_ON_TENSORS_FUSE_2_MLIR LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS dynamic_linalg_matmul_on_tensors_fuse_2.mlir) -list(REMOVE_ITEM _GLOB_X_MLIR ${_GLOB_DYNAMIC_LINALG_MATMUL_ON_TENSORS_FUSE_2_MLIR}) iree_lit_test_suite( NAME lit SRCS - "${_GLOB_X_MLIR}" + "dynamic_abs.mlir" + "dynamic_add.mlir" + "dynamic_compare_and_select.mlir" + "dynamic_dot.mlir" + "dynamic_dot_general.mlir" + "dynamic_torch_index_select_high_rank.mlir" + "dynamic_torch_index_select_negative.mlir" + "dynamic_torch_index_select_scalar.mlir" + "dynamic_torch_index_select_vector.mlir" + "executable_benchmark.mlir" + "globals.mlir" + "scalar.mlir" + "trace_dispatch_tensors.mlir" + "unused_args.mlir" DATA iree::tools::IreeFileCheck iree::tools::iree-benchmark-module diff --git a/iree/test/e2e/structural/BUILD b/iree/test/e2e/structural/BUILD index 9ffecf240ada..09585a71f698 100644 --- a/iree/test/e2e/structural/BUILD +++ b/iree/test/e2e/structural/BUILD @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob") load("//build_tools/bazel:iree_check_test.bzl", "iree_check_single_backend_test_suite") package( @@ -23,7 +24,15 @@ package( # TODO(#2395): Enable all the tests for both LLVM and SPIR-V. iree_check_single_backend_test_suite( name = "check_vmla_vmla", - srcs = glob(["*.mlir"]), + srcs = enforce_glob( + [ + "fused_dispatch_region.mlir", + "gather_add.mlir", + "gather_concat.mlir", + "matmul_add.mlir", + ], + include = ["*.mlir"], + ), driver = "vmla", target_backend = "vmla", ) diff --git a/iree/test/e2e/structural/CMakeLists.txt b/iree/test/e2e/structural/CMakeLists.txt index d8be71771d13..37a2f9042a52 100644 --- a/iree/test/e2e/structural/CMakeLists.txt +++ b/iree/test/e2e/structural/CMakeLists.txt @@ -10,12 +10,14 @@ iree_add_all_subdirs() -file(GLOB _GLOB_X_MLIR LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.mlir) iree_check_single_backend_test_suite( NAME check_vmla_vmla SRCS - "${_GLOB_X_MLIR}" + "fused_dispatch_region.mlir" + "gather_add.mlir" + "gather_concat.mlir" + "matmul_add.mlir" TARGET_BACKEND "vmla" DRIVER diff --git a/iree/test/e2e/tosa_ops/BUILD b/iree/test/e2e/tosa_ops/BUILD index 5eecc23ad5c0..cb3a05392252 100644 --- a/iree/test/e2e/tosa_ops/BUILD +++ b/iree/test/e2e/tosa_ops/BUILD @@ -18,6 +18,7 @@ # written using the IREE Check framework. # See https://google.github.io/iree/developing-iree/testing-guide#iree-core-end-to-end-tests. +load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob") load("//build_tools/bazel:iree_check_test.bzl", "iree_check_single_backend_test_suite") package( @@ -26,16 +27,50 @@ package( licenses = ["notice"], # Apache 2.0 ) +ALL_SRCS = enforce_glob( + [ + "abs.mlir", + "add.mlir", + "bitwise_and.mlir", + "bitwise_or.mlir", + "bitwise_xor.mlir", + "ceil.mlir", + "clamp.mlir", + "const.mlir", + "exp.mlir", + "floor.mlir", + "greater.mlir", + "greater_equal.mlir", + "if.mlir", + "log.mlir", + "logical_left_shift.mlir", + "logical_right_shift.mlir", + "maximum.mlir", + "minimum.mlir", + "mul.mlir", + "negate.mlir", + "reluN.mlir", + "reshape.mlir", + "rsqrt.mlir", + "select.mlir", + "sub.mlir", + "tanh.mlir", + "transpose.mlir", + "while.mlir", + ], + include = ["*.mlir"], +) + iree_check_single_backend_test_suite( name = "check_vulkan-spirv_vulkan", - srcs = glob(["*.mlir"]), + srcs = ALL_SRCS, driver = "vulkan", target_backend = "vulkan-spirv", ) iree_check_single_backend_test_suite( name = "check_dylib-llvm-aot_dylib", - srcs = glob(["*.mlir"]), + srcs = ALL_SRCS, driver = "dylib", target_backend = "dylib-llvm-aot", ) diff --git a/iree/test/e2e/tosa_ops/CMakeLists.txt b/iree/test/e2e/tosa_ops/CMakeLists.txt index 71543d37f47e..3948a95a51dd 100644 --- a/iree/test/e2e/tosa_ops/CMakeLists.txt +++ b/iree/test/e2e/tosa_ops/CMakeLists.txt @@ -10,24 +10,76 @@ iree_add_all_subdirs() -file(GLOB _GLOB_X_MLIR LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.mlir) iree_check_single_backend_test_suite( NAME check_vulkan-spirv_vulkan SRCS - "${_GLOB_X_MLIR}" + "abs.mlir" + "add.mlir" + "bitwise_and.mlir" + "bitwise_or.mlir" + "bitwise_xor.mlir" + "ceil.mlir" + "clamp.mlir" + "const.mlir" + "exp.mlir" + "floor.mlir" + "greater.mlir" + "greater_equal.mlir" + "if.mlir" + "log.mlir" + "logical_left_shift.mlir" + "logical_right_shift.mlir" + "maximum.mlir" + "minimum.mlir" + "mul.mlir" + "negate.mlir" + "reluN.mlir" + "reshape.mlir" + "rsqrt.mlir" + "select.mlir" + "sub.mlir" + "tanh.mlir" + "transpose.mlir" + "while.mlir" TARGET_BACKEND "vulkan-spirv" DRIVER "vulkan" ) -file(GLOB _GLOB_X_MLIR LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.mlir) iree_check_single_backend_test_suite( NAME check_dylib-llvm-aot_dylib SRCS - "${_GLOB_X_MLIR}" + "abs.mlir" + "add.mlir" + "bitwise_and.mlir" + "bitwise_or.mlir" + "bitwise_xor.mlir" + "ceil.mlir" + "clamp.mlir" + "const.mlir" + "exp.mlir" + "floor.mlir" + "greater.mlir" + "greater_equal.mlir" + "if.mlir" + "log.mlir" + "logical_left_shift.mlir" + "logical_right_shift.mlir" + "maximum.mlir" + "minimum.mlir" + "mul.mlir" + "negate.mlir" + "reluN.mlir" + "reshape.mlir" + "rsqrt.mlir" + "select.mlir" + "sub.mlir" + "tanh.mlir" + "transpose.mlir" + "while.mlir" TARGET_BACKEND "dylib-llvm-aot" DRIVER diff --git a/iree/test/e2e/vulkan_specific/BUILD b/iree/test/e2e/vulkan_specific/BUILD index b3408042ebf1..689361b25a31 100644 --- a/iree/test/e2e/vulkan_specific/BUILD +++ b/iree/test/e2e/vulkan_specific/BUILD @@ -15,6 +15,7 @@ # Tests for end-to-end IREE support specific to the vulkan-spirv lowering. # TODO(ravishankarm): Reorganize these tests. +load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob") load("//build_tools/bazel:iree_check_test.bzl", "iree_check_single_backend_test_suite") package( @@ -25,8 +26,17 @@ package( iree_check_single_backend_test_suite( name = "check_vulkan-spirv_vulkan", - srcs = glob( - ["*.mlir"], + srcs = enforce_glob( + [ + "compare.mlir", + "conv.mlir", + "dot_general.mlir", + "log_plus_one.mlir", + "pw_add_multiwg.mlir", + "reduce.mlir", + "vectorized_conv.mlir", + ], + include = ["*.mlir"], exclude = [ "gemm.mlir", ], diff --git a/iree/test/e2e/vulkan_specific/CMakeLists.txt b/iree/test/e2e/vulkan_specific/CMakeLists.txt index 53d6be831941..f476cdaaf052 100644 --- a/iree/test/e2e/vulkan_specific/CMakeLists.txt +++ b/iree/test/e2e/vulkan_specific/CMakeLists.txt @@ -10,14 +10,17 @@ iree_add_all_subdirs() -file(GLOB _GLOB_X_MLIR LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.mlir) -file(GLOB _GLOB_GEMM_MLIR LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS gemm.mlir) -list(REMOVE_ITEM _GLOB_X_MLIR ${_GLOB_GEMM_MLIR}) iree_check_single_backend_test_suite( NAME check_vulkan-spirv_vulkan SRCS - "${_GLOB_X_MLIR}" + "compare.mlir" + "conv.mlir" + "dot_general.mlir" + "log_plus_one.mlir" + "pw_add_multiwg.mlir" + "reduce.mlir" + "vectorized_conv.mlir" TARGET_BACKEND "vulkan-spirv" DRIVER diff --git a/iree/test/e2e/xla_ops/BUILD b/iree/test/e2e/xla_ops/BUILD index bd6eff8b6d2a..377b40c5a9e0 100644 --- a/iree/test/e2e/xla_ops/BUILD +++ b/iree/test/e2e/xla_ops/BUILD @@ -19,6 +19,7 @@ # See https://google.github.io/iree/TestingGuide#iree-core-end-to-end-tests. load("//build_tools/bazel:iree_check_test.bzl", "iree_check_single_backend_test_suite") +load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob") package( default_visibility = ["//visibility:public"], @@ -28,7 +29,57 @@ package( iree_check_single_backend_test_suite( name = "check_vmla_vmla", - srcs = glob(["*.mlir"]), + srcs = enforce_glob( + [ + "abs.mlir", + "add.mlir", + "batch_norm_inference.mlir", + "broadcast.mlir", + "broadcast_add.mlir", + "broadcast_in_dim.mlir", + "clamp.mlir", + "compare.mlir", + "concatenate.mlir", + "constant.mlir", + "convert.mlir", + "convolution.mlir", + "cosine.mlir", + "divide.mlir", + "dot.mlir", + "dot_general.mlir", + "exponential.mlir", + "exponential_minus_one.mlir", + "finite.mlir", + "floor.mlir", + "gather.mlir", + "iota.mlir", + "log.mlir", + "log_plus_one.mlir", + "maximum.mlir", + "minimum.mlir", + "multiply.mlir", + "negate.mlir", + "pad.mlir", + "reduce.mlir", + "reduce_window.mlir", + "remainder.mlir", + "reshape.mlir", + "reverse.mlir", + "round.mlir", + "rsqrt.mlir", + "select.mlir", + "sine.mlir", + "slice.mlir", + "sort.mlir", + "sqrt.mlir", + "subtract.mlir", + "tanh.mlir", + "torch_index_select.mlir", + "transpose.mlir", + "while.mlir", + ], + include = ["*.mlir"], + ), driver = "vmla", target_backend = "vmla", ) diff --git a/iree/test/e2e/xla_ops/CMakeLists.txt b/iree/test/e2e/xla_ops/CMakeLists.txt index 9a413975afb2..b5bdab60e4cb 100644 --- a/iree/test/e2e/xla_ops/CMakeLists.txt +++ b/iree/test/e2e/xla_ops/CMakeLists.txt @@ -10,12 +10,56 @@ iree_add_all_subdirs() -file(GLOB _GLOB_X_MLIR LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.mlir) iree_check_single_backend_test_suite( NAME check_vmla_vmla SRCS - "${_GLOB_X_MLIR}" + "abs.mlir" + "add.mlir" + "batch_norm_inference.mlir" + "broadcast.mlir" + "broadcast_add.mlir" + "broadcast_in_dim.mlir" + "clamp.mlir" + "compare.mlir" + "concatenate.mlir" + "constant.mlir" + "convert.mlir" + "convolution.mlir" + "cosine.mlir" + "divide.mlir" + "dot.mlir" + "dot_general.mlir" + "exponential.mlir" + "exponential_minus_one.mlir" + "finite.mlir" + "floor.mlir" + "gather.mlir" + "iota.mlir" + "log.mlir" + "log_plus_one.mlir" + "maximum.mlir" + "minimum.mlir" + "multiply.mlir" + "negate.mlir" + "pad.mlir" + "reduce.mlir" + "reduce_window.mlir" + "remainder.mlir" + "reshape.mlir" + "reverse.mlir" + "round.mlir" + "rsqrt.mlir" + "select.mlir" + "sine.mlir" + "slice.mlir" + "sort.mlir" + "sqrt.mlir" + "subtract.mlir" + "tanh.mlir" + "torch_index_select.mlir" + "transpose.mlir" + "while.mlir" TARGET_BACKEND "vmla" DRIVER diff --git a/iree/tools/test/BUILD b/iree/tools/test/BUILD index 0978d1c7b113..bcabe82a6d9b 100644 --- a/iree/tools/test/BUILD +++ b/iree/tools/test/BUILD @@ -15,6 +15,7 @@ # Smoke tests for the execution of tool binaries. load("//iree:lit_test.bzl", "iree_lit_test_suite") +load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob") package( default_visibility = ["//visibility:public"], @@ -24,7 +25,18 @@ package( iree_lit_test_suite( name = "lit", - srcs = glob(["*.mlir"]), + srcs = enforce_glob( + [ + "iree-benchmark-module.mlir", + "iree-run-mlir.mlir", + "iree-run-module.mlir", + "multiple_args.mlir", + "multiple_exported_functions.mlir", + "repeated_return.mlir", + "scalars.mlir", + ], + include = ["*.mlir"], + ), data = [ "//iree/tools:IreeFileCheck", "//iree/tools:iree-benchmark-module", diff --git a/iree/tools/test/CMakeLists.txt b/iree/tools/test/CMakeLists.txt index 1342a45f7b87..bf6581f0422c 100644 --- a/iree/tools/test/CMakeLists.txt +++ b/iree/tools/test/CMakeLists.txt @@ -10,12 +10,17 @@ iree_add_all_subdirs() -file(GLOB _GLOB_X_MLIR LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.mlir) iree_lit_test_suite( NAME lit SRCS - "${_GLOB_X_MLIR}" + "iree-benchmark-module.mlir" + "iree-run-mlir.mlir" + "iree-run-module.mlir" + "multiple_args.mlir" + "multiple_exported_functions.mlir" + "repeated_return.mlir" + "scalars.mlir" DATA iree::tools::IreeFileCheck iree::tools::iree-benchmark-module From 565c97c381e46bd8eaf8618783de1a462cc9f962 Mon Sep 17 00:00:00 2001 From: Marius Brehler Date: Thu, 11 Mar 2021 19:45:28 +0100 Subject: [PATCH 09/17] Add vm.const.i{32,64}.zero to emitc conversions (#5066) --- .../Conversion/VMToEmitC/ConvertVMToEmitC.cpp | 17 +++++++++++++++++ .../VM/Conversion/VMToEmitC/test/const_ops.mlir | 16 +++++++++++++--- .../VMToEmitC/test/const_ops_i64.mlir | 16 +++++++++++++--- 3 files changed, 43 insertions(+), 6 deletions(-) diff --git a/iree/compiler/Dialect/VM/Conversion/VMToEmitC/ConvertVMToEmitC.cpp b/iree/compiler/Dialect/VM/Conversion/VMToEmitC/ConvertVMToEmitC.cpp index a788be853c31..9664c6789c06 100644 --- a/iree/compiler/Dialect/VM/Conversion/VMToEmitC/ConvertVMToEmitC.cpp +++ b/iree/compiler/Dialect/VM/Conversion/VMToEmitC/ConvertVMToEmitC.cpp @@ -103,6 +103,21 @@ class ConstOpConversion : public OpRewritePattern { } }; +template +class ConstZeroOpConversion : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(ConstZeroOpTy constZeroOp, + PatternRewriter &rewriter) const final { + auto type = constZeroOp.getType(); + IntegerAttr value = rewriter.getIntegerAttr(type, 0); + + rewriter.replaceOpWithNewOp(constZeroOp, type, value); + return success(); + } +}; + template class GlobalLoadOpConversion : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; @@ -190,6 +205,7 @@ void populateVMToCPatterns(MLIRContext *context, // Constants patterns.insert>(context); + patterns.insert>(context); // Conditional assignment ops patterns.insert>(context, @@ -253,6 +269,7 @@ void populateVMToCPatterns(MLIRContext *context, // ExtI64: Constants patterns.insert>(context); + patterns.insert>(context); // ExtI64: Conditional assignment ops patterns.insert>(context, diff --git a/iree/compiler/Dialect/VM/Conversion/VMToEmitC/test/const_ops.mlir b/iree/compiler/Dialect/VM/Conversion/VMToEmitC/test/const_ops.mlir index 451bf7d81904..b17486cc2839 100644 --- a/iree/compiler/Dialect/VM/Conversion/VMToEmitC/test/const_ops.mlir +++ b/iree/compiler/Dialect/VM/Conversion/VMToEmitC/test/const_ops.mlir @@ -1,7 +1,18 @@ // RUN: iree-opt -split-input-file -pass-pipeline='vm.module(iree-convert-vm-to-emitc)' %s | IreeFileCheck %s -// CHECK: vm.module @module { -vm.module @module { + +vm.module @my_module { + // CHECK-LABEL: vm.func @const_i32_zero + vm.func @const_i32_zero() -> i32 { + // CHECK: %zero = "emitc.const"() {value = 0 : i32} : () -> i32 + %zero = vm.const.i32.zero : i32 + vm.return %zero : i32 + } +} + +// ----- + +vm.module @my_module { // CHECK-LABEL: vm.func @const_i32 vm.func @const_i32() { // CHECK-NEXT: %0 = "emitc.const"() {value = 0 : i32} : () -> i32 @@ -10,7 +21,6 @@ vm.module @module { %1 = vm.const.i32 2 : i32 // CHECK-NEXT: %2 = "emitc.const"() {value = -2 : i32} : () -> i32 %2 = vm.const.i32 -2 : i32 - // CHECK-NEXT: vm.return vm.return } } diff --git a/iree/compiler/Dialect/VM/Conversion/VMToEmitC/test/const_ops_i64.mlir b/iree/compiler/Dialect/VM/Conversion/VMToEmitC/test/const_ops_i64.mlir index 89b205ad8a1a..0161dc4bc29e 100644 --- a/iree/compiler/Dialect/VM/Conversion/VMToEmitC/test/const_ops_i64.mlir +++ b/iree/compiler/Dialect/VM/Conversion/VMToEmitC/test/const_ops_i64.mlir @@ -1,7 +1,18 @@ // RUN: iree-opt -split-input-file -pass-pipeline='vm.module(iree-convert-vm-to-emitc)' %s | IreeFileCheck %s -// CHECK: vm.module @module { -vm.module @module { + +vm.module @my_module { + // CHECK-LABEL: vm.func @const_i64_zero + vm.func @const_i64_zero() -> i64 { + // CHECK: %zero = "emitc.const"() {value = 0 : i64} : () -> i64 + %zero = vm.const.i64.zero : i64 + vm.return %zero : i64 + } +} + +// ----- + +vm.module @my_module { // CHECK-LABEL: vm.func @const_i64 vm.func @const_i64() { // CHECK-NEXT: %0 = "emitc.const"() {value = 0 : i64} : () -> i64 @@ -10,7 +21,6 @@ vm.module @module { %1 = vm.const.i64 2 : i64 // CHECK-NEXT: %2 = "emitc.const"() {value = -2 : i64} : () -> i64 %2 = vm.const.i64 -2 : i64 - // CHECK-NEXT: vm.return vm.return } } From 7116f0ce2b0917ec5530a39b9e698b2be3246e32 Mon Sep 17 00:00:00 2001 From: Lei Zhang Date: Thu, 11 Mar 2021 13:54:34 -0500 Subject: [PATCH 10/17] Plumb GPU conv and matmul vectorization through flow.dispatch.workgroups (#4999) This commit adds all necessary plumbing to connect 2-D convolution and matmul vectorization inside flow.dispatch.workgroups. This includes: * Let 2-D convolution be recognized as a root op during dispatch region formation. * Add a pass to concretize the abstract tiling and distribution during flow dispatch region formation. It substitutes symbolic ops with concrete values from CodeGen policy. * Recognize `hal.interface.workgroup.id` ops when folding GPU processor ID uses. * Recognize `hal.interface.binding.subspan` when vectorizing `memref`s for better memory access. Along the way, the old path is refactored to have a better structure for further cleaning up: * `LinalgTileAndDistributePass` is split out of `LinalgTileAndFusePass`. It will be used for tiling and distribution among workgroups in the old path. `LinalgTileAndFusePass` will be for tiling and vectorization in a single workgroup (and it will be renamed later). * A few tests are updated to use the new path. --- .../Conversion/Common/LaunchConfig.cpp | 8 +- .../compiler/Conversion/Common/LaunchConfig.h | 4 - .../test/tile_and_distribute.mlir | 3 +- iree/compiler/Conversion/LinalgToSPIRV/BUILD | 3 + .../Conversion/LinalgToSPIRV/CMakeLists.txt | 3 + .../ConcretizeTileAmongWorkgroupsPass.cpp | 565 ++++++++++++++++++ .../LinalgToSPIRV/FoldGPUProcessorIDUses.cpp | 39 +- .../LinalgToSPIRV/KernelDispatchUtils.cpp | 93 ++- .../LinalgTileAndDistributePass.cpp | 155 +++++ .../LinalgToSPIRV/LinalgTileAndFusePass.cpp | 84 +-- .../Conversion/LinalgToSPIRV/Passes.cpp | 8 +- .../Conversion/LinalgToSPIRV/Passes.h | 10 + .../LinalgToSPIRV/VectorizeMemref.cpp | 40 +- .../test/batch_matmul_vectorization.mlir | 2 +- .../concretize_tile_among_workgroups.mlir | 225 +++++++ .../test/fold-gpu-procid-uses.mlir | 34 ++ .../test/linalg_tile_and_fuse.mlir | 147 +---- .../materialize_launch_configuration.mlir | 2 +- .../test/matmul_fused_vectorization.mlir | 66 +- .../test/matmul_vectorization.mlir | 4 +- .../test/memref_vecrotization.mlir | 21 + .../test/tile_and_vectorize_conv.mlir | 176 ++++++ .../test/tile_and_vectorize_matmul.mlir | 61 ++ .../test/workgroup_memory_promotion.mlir | 2 +- iree/compiler/Conversion/init_conversions.h | 1 + .../Transforms/DispatchLinalgOnTensors.cpp | 53 +- .../test/dispatch_linalg_on_tensors.mlir | 30 + iree/test/e2e/vulkan_specific/BUILD | 2 + iree/test/e2e/vulkan_specific/CMakeLists.txt | 2 + 29 files changed, 1509 insertions(+), 334 deletions(-) create mode 100644 iree/compiler/Conversion/LinalgToSPIRV/ConcretizeTileAmongWorkgroupsPass.cpp create mode 100644 iree/compiler/Conversion/LinalgToSPIRV/LinalgTileAndDistributePass.cpp create mode 100644 iree/compiler/Conversion/LinalgToSPIRV/test/concretize_tile_among_workgroups.mlir create mode 100644 iree/compiler/Conversion/LinalgToSPIRV/test/tile_and_vectorize_conv.mlir create mode 100644 iree/compiler/Conversion/LinalgToSPIRV/test/tile_and_vectorize_matmul.mlir diff --git a/iree/compiler/Conversion/Common/LaunchConfig.cpp b/iree/compiler/Conversion/Common/LaunchConfig.cpp index 88b99f533d2f..246257494f15 100644 --- a/iree/compiler/Conversion/Common/LaunchConfig.cpp +++ b/iree/compiler/Conversion/Common/LaunchConfig.cpp @@ -28,6 +28,7 @@ #include "llvm/Support/FormatVariadic.h" #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h" #include "mlir/Dialect/Linalg/IR/LinalgOps.h" +#include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/Operation.h" @@ -38,6 +39,7 @@ namespace iree_compiler { /// Name of the StrAttr that can be used to get the key to access the tile size /// information. static const char kLaunchInfoKey[] = "launch_info_key"; +static const char kRootOpKey[] = "is_root_op"; static Optional getKey(Operation *op) { StringAttr attr = op->getAttrOfType(kLaunchInfoKey); @@ -82,8 +84,7 @@ ArrayRef LaunchConfig::getTileSizes(Operation *op, Operation *LaunchConfig::getRootOperation(ArrayRef ops) { for (auto op : ops) { - auto key = getKey(op); - if (key && key.getValue() == rootOperationKey) return op; + if (op->getAttrOfType(kRootOpKey)) return op; } return nullptr; } @@ -114,8 +115,7 @@ void LaunchConfig::setNumSubgroups(ArrayRef vNumSubgroups) { } void LaunchConfig::setRootOperation(Operation *op) { - Optional key = getKey(op); - if (key) rootOperationKey = *key; + op->setAttr(kRootOpKey, UnitAttr::get(op->getContext())); } void LaunchConfig::setSameConfig(Operation *source, Operation *target) { diff --git a/iree/compiler/Conversion/Common/LaunchConfig.h b/iree/compiler/Conversion/Common/LaunchConfig.h index 7f6f3366c788..1367f6054f05 100644 --- a/iree/compiler/Conversion/Common/LaunchConfig.h +++ b/iree/compiler/Conversion/Common/LaunchConfig.h @@ -127,10 +127,6 @@ class LaunchConfig { /// these attributes. llvm::StringMap tileSizes; - /// Key used for tagging the root operation. The launch config does not track - /// the root operation itself, but rather the key used for the root operation. - StringRef rootOperationKey = ""; - /// Workgroup size to use. std::array workgroupSize = {1, 1, 1}; diff --git a/iree/compiler/Conversion/LinalgToLLVM/test/tile_and_distribute.mlir b/iree/compiler/Conversion/LinalgToLLVM/test/tile_and_distribute.mlir index 4868d059698f..a430f6334824 100644 --- a/iree/compiler/Conversion/LinalgToLLVM/test/tile_and_distribute.mlir +++ b/iree/compiler/Conversion/LinalgToLLVM/test/tile_and_distribute.mlir @@ -84,4 +84,5 @@ hal.executable @static_matmul attributes {sym_visibility = "private"} { // CHECK: %[[J:.+]] = affine.apply #[[MAP2]]()[%[[THREAD_X_ID]]] // CHECK: %[[RHS_SUBVIEW:.+]] = subview %[[RHS]][%[[K]], %[[J]]] [1, 4] [1, 1] : memref<4x8xf32> to memref<1x4xf32, #[[MAP3]]> // CHECK: %[[RESULT_SUBVIEW:.+]] = subview %[[RESULT]][%[[I]], %[[J]]] [2, 4] [1, 1] : memref<16x8xf32> to memref<2x4xf32, #[[MAP3]]> -// CHECK: linalg.matmul {__internal_linalg_transform__ = "workgroup"} ins(%[[LHS_SUBVIEW]], %[[RHS_SUBVIEW]] : memref<2x1xf32, #[[MAP1]]>, memref<1x4xf32, #[[MAP3]]>) outs(%4 : memref<2x4xf32, #[[MAP3]]>) +// CHECK: linalg.matmul +//CHECK-SAME: ins(%[[LHS_SUBVIEW]], %[[RHS_SUBVIEW]] : memref<2x1xf32, #[[MAP1]]>, memref<1x4xf32, #[[MAP3]]>) outs(%4 : memref<2x4xf32, #[[MAP3]]>) diff --git a/iree/compiler/Conversion/LinalgToSPIRV/BUILD b/iree/compiler/Conversion/LinalgToSPIRV/BUILD index b224b665d763..2aab244cd670 100644 --- a/iree/compiler/Conversion/LinalgToSPIRV/BUILD +++ b/iree/compiler/Conversion/LinalgToSPIRV/BUILD @@ -34,11 +34,13 @@ cc_library( cc_library( name = "LinalgToSPIRV", srcs = [ + "ConcretizeTileAmongWorkgroupsPass.cpp", "ConvertToGPUPass.cpp", "ConvertToSPIRVPass.cpp", "CooperativeMatrixAnalysis.cpp", "FoldGPUProcessorIDUses.cpp", "KernelDispatchUtils.cpp", + "LinalgTileAndDistributePass.cpp", "LinalgTileAndFusePass.cpp", "MatMulVectorizationTest.cpp", "Passes.cpp", @@ -61,6 +63,7 @@ cc_library( "//iree/compiler/Conversion/HLOToHLO", "//iree/compiler/Conversion/HLOToLinalg", "//iree/compiler/Conversion/LinalgToVector", + "//iree/compiler/Dialect/Flow/IR", "//iree/compiler/Dialect/HAL/IR", "//iree/compiler/Dialect/HAL/IR:HALDialect", "//iree/compiler/Dialect/IREE/IR", diff --git a/iree/compiler/Conversion/LinalgToSPIRV/CMakeLists.txt b/iree/compiler/Conversion/LinalgToSPIRV/CMakeLists.txt index e51da222e2eb..7046ccca4ae4 100644 --- a/iree/compiler/Conversion/LinalgToSPIRV/CMakeLists.txt +++ b/iree/compiler/Conversion/LinalgToSPIRV/CMakeLists.txt @@ -32,11 +32,13 @@ iree_cc_library( "Passes.h" "Utils.h" SRCS + "ConcretizeTileAmongWorkgroupsPass.cpp" "ConvertToGPUPass.cpp" "ConvertToSPIRVPass.cpp" "CooperativeMatrixAnalysis.cpp" "FoldGPUProcessorIDUses.cpp" "KernelDispatchUtils.cpp" + "LinalgTileAndDistributePass.cpp" "LinalgTileAndFusePass.cpp" "MatMulVectorizationTest.cpp" "Passes.cpp" @@ -74,6 +76,7 @@ iree_cc_library( iree::compiler::Conversion::HLOToHLO iree::compiler::Conversion::HLOToLinalg iree::compiler::Conversion::LinalgToVector + iree::compiler::Dialect::Flow::IR iree::compiler::Dialect::HAL::IR iree::compiler::Dialect::HAL::IR::HALDialect iree::compiler::Dialect::IREE::IR diff --git a/iree/compiler/Conversion/LinalgToSPIRV/ConcretizeTileAmongWorkgroupsPass.cpp b/iree/compiler/Conversion/LinalgToSPIRV/ConcretizeTileAmongWorkgroupsPass.cpp new file mode 100644 index 000000000000..6973241ddae7 --- /dev/null +++ b/iree/compiler/Conversion/LinalgToSPIRV/ConcretizeTileAmongWorkgroupsPass.cpp @@ -0,0 +1,565 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//===- ConcretizeTileAmongWorkgroupsPass.cpp ------------------------------===// +// +// This pass concretizes hal.interface.workgroup ops by replacing them with +// constant values from the chosen tiling and distribution scheme. +// +// During dispatch region formation in IREE Flow transformations, ops are tiled +// and distributed in an abstract way by using symbolic hal.interface.workgroup +// ops. That is because the same source region is compiled towards different +// target backends and each target backend could use different tiling and +// distribution schemes. However, after HAL interface materialization, the +// hal.executable.target is just meant for one target backend. We need to +// concretize the tiling and distribution in order to inject static information +// for further compilation. +// +// This pass performs the conretization in two modes: +// +// 1) Partically static: where have a concrete tiling and distirbution sheme +// *but not* a full static original problem size (e.g., due to dynamic +// shapes). Under such circumstances, we can only replace ops like +// hal.interface.workgroup.size ops and still need to compute the number +// of workgroups using symbolic values. +// 2) Fully static: where we have a concrete tiling and distribution scheme +// *and* the full static original problem size. Under such circumstances, +// we can fully deduce the number of workgroups to dispatch and replace +// hal.interface.workgroup.count ops with constant values too. +// +//===----------------------------------------------------------------------===// + +#include "iree/compiler/Conversion/Common/LaunchConfig.h" +#include "iree/compiler/Conversion/Common/Transforms.h" +#include "iree/compiler/Conversion/LinalgToSPIRV/CodeGenOptionUtils.h" +#include "iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.h" +#include "iree/compiler/Conversion/LinalgToSPIRV/Passes.h" +#include "iree/compiler/Conversion/LinalgToSPIRV/Utils.h" +#include "iree/compiler/Dialect/Flow/IR/FlowOps.h" +#include "iree/compiler/Dialect/HAL/IR/HALOps.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/Debug.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h" +#include "mlir/Dialect/Linalg/IR/LinalgOps.h" +#include "mlir/Dialect/SCF/SCF.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/IR/AffineExpr.h" +#include "mlir/IR/AffineMap.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#define DEBUG_TYPE "iree-spirv-concretize-tile-among-workgroups" + +namespace mlir { +namespace iree_compiler { + +namespace { + +constexpr unsigned kWorkgroupDimCount = 3; + +int64_t ceilDiv(int64_t a, int64_t b) { return (a + b - 1) / b; } + +static size_t getNumOuterParallelDims(linalg::LinalgOp op) { + ArrayRef iterators = op.iterator_types().getValue(); + auto parallels = iterators.take_while( + [](Attribute attr) { return linalg::isParallelIteratorType(attr); }); + return parallels.size(); +} + +/// Returns the root Linalg op that dictates tiling and distribution policy. +linalg::LinalgOp getRootLinalgOp(FuncOp funcOp) { + SmallVector linalgOps; + SmallVector tiledLoops; + if (failed(getLinalgOps(funcOp, linalgOps, tiledLoops))) return {}; + + SPIRVCodegenOptions options; + options.enableVectorization = true; + options.usingLinalgOnTensors = true; + linalg::Aliases aliases; + linalg::LinalgDependenceGraph dependenceGraph(aliases, linalgOps); + Optional launchConfigOpt = initGPULaunchConfig( + funcOp.getContext(), dependenceGraph, options, linalgOps); + if (!launchConfigOpt) return {}; + + LaunchConfig &launchConfig = *launchConfigOpt; + Operation *rootOp = + launchConfig.getRootOperation(llvm::to_vector<4>(llvm::map_range( + linalgOps, [](linalg::LinalgOp op) { return op.getOperation(); }))); + + // Clean up internal markers that are set during launch configuration + // preparation. + launchConfig.finalize(funcOp); + + return rootOp; +} + +/// Assuming the given `rootOp` is the tiled root Linalg op, returns the +/// original input/output types for all tiles. +/// +/// Note: After the abstract tiling and distribution in Flow dispatch region +/// creation, the anchoring root op is already in a loop nest and works on a +/// tile. The full type for all tiles in the IR is not explicit anymore after +/// HAL interface is materialized. So go through the IR use chain to figure it +/// out. Otherwise we need to make even more assumptions in the following. +// TODO(antiagainst): This is quite fragile. We need a better way to pass the +// information down from the upper layer, which readily has it. Probably via +// linalg.tile op. +LogicalResult getInputOutputTypesForAllTiles( + linalg::LinalgOp rootOp, SmallVectorImpl &inputTypes, + SmallVectorImpl &outputTypes) { + for (Value inputBuffer : rootOp.getInputBuffers()) { + auto subviewOp = inputBuffer.getDefiningOp(); + if (!subviewOp) return failure(); + inputTypes.push_back(subviewOp.getViewSource().getType()); + } + + for (Value outputBuffer : rootOp.getOutputBuffers()) { + auto subviewOp = outputBuffer.getDefiningOp(); + if (!subviewOp) return failure(); + outputTypes.push_back(subviewOp.getViewSource().getType()); + } + + return success(); +} + +/// Assuming the given `rootOp` is the tiled root Linalg op, returns the +/// tile sizes for distributing to workgroups and the workgroups size for the +/// generated kernel. +/// +/// TODO(antiagainst): This pass can be shared between CPU and GPU. But the +/// following query scopes it to GPU for now. +llvm::Optional, ArrayRef>> +getTileSizeAndWorkgroupSize(Operation *rootOp, ArrayRef inputTypes, + ArrayRef outputTypes) { + // Build necesary structures to query the tile sizes for distributing to + // workgroups. + linalg::Aliases aliases; + SmallVector linalgOps; + auto ops = rootOp->getBlock()->getOps(); + linalgOps.assign(ops.begin(), ops.end()); + linalg::LinalgDependenceGraph dependenceGraph(aliases, linalgOps); + SPIRVCodegenOptions options; + + // NOTE: Launch configuration expects the original input/output type to decide + // the configuration. But we have already tiled the Linalg ops here. Use an + // attribute to send it over for now. + const char inputTypeAttrName[] = "iree.codegen.original_input_types"; + const char outputTypeAttrName[] = "iree.codegen.original_output_types"; + if (!inputTypes.empty()) { + rootOp->setAttr(inputTypeAttrName, + Builder(rootOp).getTypeArrayAttr(inputTypes)); + } + if (!outputTypes.empty()) { + rootOp->setAttr(outputTypeAttrName, + Builder(rootOp).getTypeArrayAttr(outputTypes)); + } + + Optional launchConfig = initGPULaunchConfig( + rootOp->getContext(), dependenceGraph, options, linalgOps); + if (!launchConfig) { + rootOp->emitError("unable to find launch configuration"); + return llvm::None; + } + + ArrayRef tileSize = launchConfig->getTileSizes(rootOp, 0); + ArrayRef workgroupSize = launchConfig->getWorkgroupSize(); + + // Clean up internal markers that are set during launch configuration + // preparation. + launchConfig->finalize(rootOp->getParentOfType()); + + return std::make_pair(tileSize, workgroupSize); +} + +/// Replaces hal.interface.workgroup.size op with the constant value chosen +/// from tiling scheme. +class ConcretizeWorkgroupSizeOp final + : public OpRewritePattern { + public: + ConcretizeWorkgroupSizeOp(MLIRContext *context, + SmallVector workloadSize, + SmallVector tileSize, + PatternBenefit benefit = 1) + : OpRewritePattern(context, benefit), + workloadSize(std::move(workloadSize)), + tileSize(std::move(tileSize)) {} + + LogicalResult matchAndRewrite(IREE::HAL::InterfaceWorkgroupSizeOp op, + PatternRewriter &rewriter) const override { + unsigned dimIndex = op.dimension().getZExtValue(); + + if (dimIndex < kWorkgroupDimCount && tileSize[dimIndex] != 0) { + rewriter.replaceOpWithNewOp( + op, rewriter.getIndexAttr(tileSize[dimIndex])); + return success(); + } + + return failure(); + } + + private: + SmallVector workloadSize; + SmallVector tileSize; +}; + +/// Replaces hal.interface.workgroup.count op with the constant value chosen +/// from tiling scheme. +class ConcretizeWorkgroupCountOp final + : public OpRewritePattern { + public: + ConcretizeWorkgroupCountOp(MLIRContext *context, + SmallVector workloadSize, + SmallVector tileSize, + PatternBenefit benefit = 1) + : OpRewritePattern(context, benefit), + workloadSize(workloadSize), + tileSize(std::move(tileSize)) {} + + LogicalResult matchAndRewrite(IREE::HAL::InterfaceWorkgroupCountOp op, + PatternRewriter &rewriter) const override { + unsigned dimIndex = op.dimension().getZExtValue(); + + if (dimIndex >= kWorkgroupDimCount) return failure(); + + int64_t dimSize = workloadSize[dimIndex]; + int64_t dimTile = tileSize[dimIndex]; + + if (dimSize == ShapedType::kDynamicSize || dimTile == 0) return failure(); + + int64_t count = ceilDiv(dimSize, dimTile); + rewriter.replaceOpWithNewOp(op, rewriter.getIndexAttr(count)); + + return success(); + } + + private: + SmallVector workloadSize; + SmallVector tileSize; +}; + +// Canonicalizes away a trip-one scf.for loop by inlining its body and removing +// the loop. +// +// This pattern is needed because in Flow abstract tiling and distribution we +// will create scf.for loops that distribute workload cyclically. After +// concretizing hal.interface.workgroup.* ops, these scf.for loops still remain, +// and they will be of the form: +// +// %lb = mul %workgroup_id_{x|y|z}, %cst_tile_size_{x|y|z} +// scf.for %iv = %lb to %cst_wokload_size_{x|y|z} +// step %cst_workload_size_{x|y|z} { ... } +// +// Such scf.for loops can be inlined if %lb is smaller than upper bound. +class RemoveTripOneLoop final : public OpRewritePattern { + public: + RemoveTripOneLoop(MLIRContext *context, SmallVector workloadSize, + SmallVector tileSize, + PatternBenefit benefit = 1) + : OpRewritePattern(context, benefit), + workloadSize(workloadSize), + tileSize(std::move(tileSize)) {} + + LogicalResult matchAndRewrite(scf::ForOp op, + PatternRewriter &rewriter) const override { + // Get constant upper bound and step values. + IntegerAttr ub, step; + if (!matchPattern(op.upperBound(), m_Constant(&ub)) || + !matchPattern(op.step(), m_Constant(&step))) { + return failure(); + } + + // Require that they are the same. + if (ub != step) return failure(); + + // Now make sure the lower bound is smaller than upper bound. The lower + // bound should be multiplying the workgroup ID with some constant. + + auto mulOp = op.lowerBound().getDefiningOp(); + if (!mulOp || mulOp.mapOperands().size() != 2) return failure(); + + AffineExpr lhs, rhs; + bindSymbols(op.getContext(), lhs, rhs); + auto mulMap = AffineMap::get(0, 2, lhs * rhs); + if (mulOp.getAffineMap() != mulMap) return failure(); + + auto mulLhs = mulOp.mapOperands().front(); + auto mulRhs = mulOp.mapOperands().back(); + + auto idOp = mulLhs.getDefiningOp(); + IntegerAttr multipler; + if (!idOp || !matchPattern(mulRhs, m_Constant(&multipler))) + return failure(); + + // We just need to make sure the max value of the workgroup ID multipled by + // the multipler is smaller than the upper bound to guarantee one trip. + unsigned dimIndex = idOp.dimension().getZExtValue(); + int64_t dimSize = workloadSize[dimIndex]; + int64_t dimTile = tileSize[dimIndex]; + + if (dimSize == ShapedType::kDynamicSize) return failure(); + + int64_t count = ceilDiv(dimSize, dimTile); + assert(count > 0 && "expected at least one tile!"); + + // ID should be in range [0, count). + if ((count - 1) * multipler.getInt() >= ub.getInt()) { + // Dead loop. It can actually be removed entirely. But we aren't expecting + // it to happen here. Do not canonicalize for such case. + return failure(); + } + + SmallVector blockArgs; + blockArgs.reserve(op.getNumIterOperands() + 1); + blockArgs.push_back(op.lowerBound()); + llvm::append_range(blockArgs, op.getIterOperands()); + + Block *block = &op.getLoopBody().front(); + Operation *terminator = block->getTerminator(); + ValueRange results = terminator->getOperands(); + rewriter.mergeBlockBefore(block, op, blockArgs); + rewriter.replaceOp(op, results); + rewriter.eraseOp(terminator); + + return success(); + } + + private: + SmallVector workloadSize; + SmallVector tileSize; +}; + +/// Concretizes hal.interface.workgroup.* ops with constants from the chosen +/// tiling sheme when possible and perform loop canonicalization afterwards. +class ConcretizeTileAmongWorkgroupsPass + : public PassWrapper> { + public: + ConcretizeTileAmongWorkgroupsPass(const SPIRVCodegenOptions &options) + : options(options) {} + ConcretizeTileAmongWorkgroupsPass( + const ConcretizeTileAmongWorkgroupsPass &that) + : options(that.options) { + inlineTripOneLoops = that.inlineTripOneLoops; + } + + void runOnOperation() override { + IREE::HAL::ExecutableTargetOp targetOp = getOperation(); + ModuleOp module = targetOp.getInnerModule(); + + for (FuncOp funcOp : module.getOps()) { + if (!funcOp.isPublic()) continue; + if (failed(runOnFunction(funcOp))) return signalPassFailure(); + } + } + + private: + LogicalResult runOnFunction(FuncOp funcOp) { + MLIRContext &context = getContext(); + + // 1. Get the root op first. We need it to figure out the original problem + // size, which then affects the tiling and distribution policy. + + linalg::LinalgOp rootOp = getRootLinalgOp(funcOp); + if (!rootOp) { + LLVM_DEBUG(llvm::dbgs() << "unable to find root Linalg op\n"); + // It can happen for ops that are not abstractly tiled during dispatch + // region formation. So don't trigger pass failure. + return success(); + } + LLVM_DEBUG(llvm::dbgs() << "Root op: " << rootOp << "\n"); + + size_t numTilableDims = getNumOuterParallelDims(rootOp); + + // 2. Figure out the original problem size. + + SmallVector inputTypes, outputTypes; + SmallVector workloadSize; + if (succeeded( + getInputOutputTypesForAllTiles(rootOp, inputTypes, outputTypes))) { + if (outputTypes.size() != 1) { + return rootOp.emitError("only support ops with one result right now"); + } + + // Flow/HAL processor id/size/count ops' indices follow the reverse order + // of the shape dimensions. + workloadSize = llvm::to_vector<4>(llvm::reverse( + outputTypes.front().cast().getShape().take_front( + numTilableDims))); + } else { + // This can happen for dynamic shapes. + LLVM_DEBUG(llvm::dbgs() + << "unable to find input/output type for all tiles"); + + inputTypes.clear(); + outputTypes.clear(); + + workloadSize.assign(numTilableDims, ShapedType::kDynamicSize); + } + + LLVM_DEBUG({ + llvm::dbgs() << "Queried workload size: "; + llvm::interleaveComma(workloadSize, llvm::dbgs()); + llvm::dbgs() << "\n"; + }); + + // 3. Query the scheme for tiling among workgroups. + + SmallVector tileSize; + SmallVector workgroupSize; + + // Try to use configuration from the command-line first for testing. + tileSize.assign(options.tileSizes.begin(), options.tileSizes.end()); + tileSize.resize(numTilableDims, 0); + workgroupSize.assign(options.workgroupSize.begin(), + options.workgroupSize.end()); + if (tileSize.empty() || workgroupSize.empty()) { + auto sizes = getTileSizeAndWorkgroupSize(rootOp, inputTypes, outputTypes); + if (sizes) { + // The tile sizes are specified against the original dimension order of + // the workload shape. But Flow/HAL processor id/size/count ops' are + // created using the reverse order. + tileSize = llvm::to_vector<4>( + llvm::reverse(sizes->first.take_front(numTilableDims))); + workgroupSize = llvm::to_vector<4>(sizes->second); + } else { + return funcOp.emitError("failed to query tile size and workgroup size"); + } + } + + LLVM_DEBUG({ + llvm::dbgs() << "Queried tile size: "; + llvm::interleaveComma(tileSize, llvm::dbgs()); + llvm::dbgs() << "\n"; + }); + + // 4. Replace hal.interface.workgroup symbolic ops with constant values. + + { + OwningRewritePatternList patterns; + patterns.insert( + &context, workloadSize, tileSize); + + (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); + } + + LLVM_DEBUG({ + llvm::dbgs() + << "--- After concretizing hal.interface.workgroup ops ---\n"; + funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope()); + llvm::dbgs() << "\n\n"; + }); + + // 5. Set the entry point region for computing the number of workgroups + // to dispatch. The region has symbolic arguments representing the workload. + // So two modes here (see comments at the begining of this file). + + { + SmallVector numWorkgroups; + for (auto pair : llvm::zip(workloadSize, tileSize)) { + auto workload = std::get<0>(pair); + auto tile = std::get<1>(pair); + if (workload == ShapedType::kDynamicSize || tile == 0) { + numWorkgroups.push_back(ShapedType::kDynamicSize); + } else { + numWorkgroups.push_back(ceilDiv(workload, tile)); + } + } + + numWorkgroups.resize(kWorkgroupDimCount, 1); + + // If all dimensions are known constant, then we can set the number of + // workgroups directly. Otherwise, we need to generate the IR for + // computing it using symbolic values. + if (llvm::none_of(numWorkgroups, [](int64_t dim) { + return dim == ShapedType::kDynamicSize; + })) { + OpBuilder builder(&context); + WorkgroupCountRegionBuilder regionBuilder = + [&](OpBuilder &builder, Location loc, std::array) { + std::array returnValues; + for (unsigned i = 0; i < kWorkgroupDimCount; ++i) { + returnValues[i] = + builder.create(loc, numWorkgroups[i]); + } + return returnValues; + }; + if (failed( + defineWorkgroupCountRegion(builder, funcOp, regionBuilder))) { + return funcOp.emitError( + "failed to set entry point region for number of workgroups"); + } + } else { + if (failed(materializeStaticLaunchInformation(funcOp, tileSize))) { + return funcOp.emitOpError( + "failed to materialize static launch information"); + } + } + } + + if (failed(updateWorkGroupSize(funcOp, workgroupSize))) { + return funcOp.emitOpError("failed to set workgroup size on function"); + } + + // 6. Canonicalization and clean up. + + if (inlineTripOneLoops) { + OwningRewritePatternList patterns; + patterns.insert(&context, workloadSize, tileSize); + + (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); + } + + return success(); + } + + private: + SPIRVCodegenOptions options; + + // TODO(#5034): Investigate whether there is a better way to prove tileability + // and canonicalize affine.min ops, without matching against the specific + // pattern involving loops. + Option inlineTripOneLoops{ + *this, "inline-trip-one-loops", + llvm::cl::desc( + "Inline a loop's body if it can be proven to just have one trip"), + llvm::cl::init(true)}; +}; + +} // namespace + +std::unique_ptr> +createConcretizeTileAmongWorkgroupsPass(const SPIRVCodegenOptions &options) { + return std::make_unique(options); +} + +static PassRegistration pass( + "iree-spirv-concretize-tile-among-workgroups", + "Replace hal.interface.workgroup.* ops with constant values from chosen " + "tiling and distribution scheme", + [] { + SPIRVCodegenOptions options = getSPIRVCodegenOptionsFromClOptions(); + return std::make_unique(options); + }); + +} // namespace iree_compiler +} // namespace mlir diff --git a/iree/compiler/Conversion/LinalgToSPIRV/FoldGPUProcessorIDUses.cpp b/iree/compiler/Conversion/LinalgToSPIRV/FoldGPUProcessorIDUses.cpp index d43044f52e40..323832d9e707 100644 --- a/iree/compiler/Conversion/LinalgToSPIRV/FoldGPUProcessorIDUses.cpp +++ b/iree/compiler/Conversion/LinalgToSPIRV/FoldGPUProcessorIDUses.cpp @@ -20,6 +20,7 @@ #include "iree/compiler/Conversion/Common/Attributes.h" #include "iree/compiler/Conversion/LinalgToSPIRV/Passes.h" +#include "iree/compiler/Dialect/HAL/IR/HALOps.h" #include "llvm/ADT/STLExtras.h" #include "llvm/Support/Debug.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" @@ -89,26 +90,34 @@ int dimensionToIndex(StringRef dimension) { return StringSwitch(dimension).Case("x", 0).Case("y", 1).Case("z", 2); } -/// Gets the block processor ID's upper bound. This queries the workgroup count -/// function. -Optional getProcessorIDUpperBound(gpu::BlockIdOp blockIDOp) { - auto funcOp = blockIDOp->getParentOfType(); +IREE::HAL::ReturnOp getEntryPointReturnOp(Operation *op) { + auto funcOp = op->getParentOfType(); auto targetOp = funcOp.getOperation()->getParentOfType(); - IREE::HAL::ExecutableEntryPointOp entryPointOp = nullptr; + + IREE::HAL::ExecutableEntryPointOp entryPointOp; for (auto op : targetOp.getOps()) { if (op.sym_name() == funcOp.getName()) { entryPointOp = op; break; } } - if (!entryPointOp) return llvm::None; + if (!entryPointOp) return {}; Operation *terminator = entryPointOp.getBlock()->getTerminator(); auto retOp = dyn_cast(terminator); - if (!retOp || retOp.getNumOperands() != 3) return llvm::None; + if (!retOp || retOp.getNumOperands() != 3) return {}; + LLVM_DEBUG(llvm::dbgs() << "workgroup count function return op: " << retOp << "\n"); + return retOp; +} + +/// Gets the block processor ID's upper bound. This queries the workgroup count +/// function. +Optional getProcessorIDUpperBound(gpu::BlockIdOp blockIDOp) { + auto retOp = getEntryPointReturnOp(blockIDOp); + if (!retOp) return llvm::None; int index = dimensionToIndex(blockIDOp.dimension()); IntegerAttr attr; @@ -118,6 +127,19 @@ Optional getProcessorIDUpperBound(gpu::BlockIdOp blockIDOp) { return attr.getInt(); } +Optional getProcessorIDUpperBound( + IREE::HAL::InterfaceWorkgroupIDOp blockIDOp) { + auto retOp = getEntryPointReturnOp(blockIDOp); + if (!retOp) return llvm::None; + + int index = blockIDOp.dimensionAttr().getInt(); + IntegerAttr attr; + if (!matchPattern(retOp.getOperand(index), m_Constant(&attr))) + return llvm::None; + + return attr.getInt(); +} + /// Gets the thread processor ID's upper bound. This queries the SPIR-V entry /// point ABI. Optional getProcessorIDUpperBound(gpu::ThreadIdOp threadIDOp) { @@ -165,6 +187,9 @@ struct FoldAffineMinOverProcessorID : OpRewritePattern { Optional ub; if (auto blockIDOp = dyn_cast(symbolOp)) { ub = getProcessorIDUpperBound(blockIDOp); + } else if (auto blockIDOp = + dyn_cast(symbolOp)) { + ub = getProcessorIDUpperBound(blockIDOp); } else if (auto threadIDOp = dyn_cast(symbolOp)) { ub = getProcessorIDUpperBound(threadIDOp); } diff --git a/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.cpp b/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.cpp index 0ef687798f0e..097a0f8f1983 100644 --- a/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.cpp +++ b/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.cpp @@ -121,8 +121,23 @@ static LogicalResult getMaliSpecificConfig( std::array &numSubgroups) { if (targetEnv.getVendorID() != spirv::Vendor::ARM) return failure(); - auto lhsType = op.inputs()[0].getType().cast(); - auto rhsType = op.inputs()[1].getType().cast(); + ShapedType lhsType, rhsType; + // NOTE: Special treatment to let the flow.dispatch.workgroups path to be able + // to query launch configurations. + if (auto inputTypeAttr = + op->getAttrOfType("iree.codegen.original_input_types")) { + lhsType = inputTypeAttr.getValue()[0] + .cast() + .getValue() + .cast(); + rhsType = inputTypeAttr.getValue()[1] + .cast() + .getValue() + .cast(); + } else { + lhsType = op.inputs()[0].getType().cast(); + rhsType = op.inputs()[1].getType().cast(); + } assert(lhsType.getElementType() == rhsType.getElementType()); if (!lhsType.hasStaticShape() || !rhsType.hasStaticShape()) return failure(); // Get a vector of best tile size ordered from best to worst. @@ -292,8 +307,21 @@ static LogicalResult getTargetSpecificConfig( std::array &numSubgroups) { if (targetEnv.getVendorID() != spirv::Vendor::ARM) return failure(); - auto lhsType = op.inputs()[0].getType().cast(); - auto rhsType = op.inputs()[1].getType().cast(); + ShapedType lhsType, rhsType; + if (auto inputTypeAttr = + op->getAttrOfType("iree.codegen.original_input_types")) { + lhsType = inputTypeAttr.getValue()[0] + .cast() + .getValue() + .cast(); + rhsType = inputTypeAttr.getValue()[1] + .cast() + .getValue() + .cast(); + } else { + lhsType = op.inputs()[0].getType().cast(); + rhsType = op.inputs()[1].getType().cast(); + } assert(lhsType.getElementType() == rhsType.getElementType()); // If the shape size is unknonw fall back to none vectorized path. if (!lhsType.hasStaticShape() || !rhsType.hasStaticShape()) return failure(); @@ -375,14 +403,34 @@ template static LogicalResult getMaliSpecificConfig(ConvOpTy op, TileSizesListType &tileSizes, LaunchConfigInfo &config) { - auto inputType = op.getInput(1).getType().template cast(); - auto outputType = op.getOutputBufferTypes()[0].template cast(); + Operation *operation = op.getOperation(); + if (!isa(operation)) return failure(); + + ShapedType inputType, outputType; + + // NOTE: Special treatment to let the flow.dispatch.workgroups path to be able + // to query launch configurations. + if (auto outputTypeAttr = operation->getAttrOfType( + "iree.codegen.original_output_types")) { + auto inputTypeAttr = operation->getAttrOfType( + "iree.codegen.original_input_types"); + inputType = inputTypeAttr.getValue()[0] + .template cast() + .getValue() + .template cast(); + outputType = outputTypeAttr.getValue()[0] + .template cast() + .getValue() + .template cast(); + LLVM_DEBUG(llvm::dbgs() << "conv input types: " << inputType << "\n"); + LLVM_DEBUG(llvm::dbgs() << "conv output types: " << outputType << "\n"); + } else { + inputType = op.getInputs().front().getType().template cast(); + outputType = op.getOutputBufferTypes()[0].template cast(); + } + if (!inputType.hasStaticShape() || !outputType.hasStaticShape()) return failure(); - // Only support NHWC conv. - if (!isa(op.getOperation())) { - return failure(); - } bool isInputTilable = inputType.getDimSize(3) % 4 == 0 || inputType.getDimSize(3) < 4; @@ -478,8 +526,29 @@ GET_CONV_LAUNCH_CONFIG(linalg::ConvInputNDHWCFilterDHWCFOp) static LogicalResult getMaliSpecificConfig( linalg::DepthwiseConvInputNHWCFilterHWCOp op, TileSizesListType &tileSizes, LaunchConfigInfo &config) { - auto inputType = op.getInput(0).getType().cast(); - auto outputType = op.getOutputBufferTypes()[0].cast(); + ShapedType inputType, outputType; + + // NOTE: Special treatment to let the flow.dispatch.workgroups path to be able + // to query launch configurations. + if (auto outputTypeAttr = + op->getAttrOfType("iree.codegen.original_output_types")) { + auto inputTypeAttr = + op->getAttrOfType("iree.codegen.original_input_types"); + inputType = inputTypeAttr.getValue()[0] + .template cast() + .getValue() + .template cast(); + outputType = outputTypeAttr.getValue()[0] + .template cast() + .getValue() + .template cast(); + LLVM_DEBUG(llvm::dbgs() << "dwconv input types: " << inputType << "\n"); + LLVM_DEBUG(llvm::dbgs() << "dwconv output types: " << outputType << "\n"); + } else { + inputType = op.getInput(0).getType().cast(); + outputType = op.getOutputBufferTypes()[0].cast(); + } + if (!inputType.hasStaticShape() || !outputType.hasStaticShape()) return failure(); diff --git a/iree/compiler/Conversion/LinalgToSPIRV/LinalgTileAndDistributePass.cpp b/iree/compiler/Conversion/LinalgToSPIRV/LinalgTileAndDistributePass.cpp new file mode 100644 index 000000000000..3c6d5e35e2bc --- /dev/null +++ b/iree/compiler/Conversion/LinalgToSPIRV/LinalgTileAndDistributePass.cpp @@ -0,0 +1,155 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//===- LinalgTileAndDistributePass.cpp ------------------------------------===// +// +// This pass tiles and distributes linalg operations among multiple workgroups. +// +// NOTE: Deprecated. This pass is used for the first-level tiling in the Linalg +// on buffers path, which is expected to go away soon. +// +//===----------------------------------------------------------------------===// + +#include "iree/compiler/Conversion/CodegenUtils/FunctionUtils.h" +#include "iree/compiler/Conversion/CodegenUtils/MarkerUtils.h" +#include "iree/compiler/Conversion/Common/Attributes.h" +#include "iree/compiler/Conversion/Common/Transforms.h" +#include "iree/compiler/Conversion/LinalgToSPIRV/CodeGenOptionUtils.h" +#include "iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.h" +#include "iree/compiler/Conversion/LinalgToSPIRV/Utils.h" +#include "iree/compiler/Dialect/HAL/IR/HALDialect.h" +#include "iree/compiler/Dialect/HAL/IR/HALOps.h" +#include "mlir/Dialect/GPU/GPUDialect.h" +#include "mlir/Dialect/Linalg/Transforms/CodegenStrategy.h" +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#define DEBUG_TYPE "iree-linalg-to-spirv-tile-and-distribute" + +namespace mlir { +namespace iree_compiler { +namespace { + +/// Returns the distribution options for operations when targeting workgroups. +linalg::LinalgLoopDistributionOptions getWorkgroupDistributionOptions() { + linalg::LinalgLoopDistributionOptions options; + + options.procInfo = [](OpBuilder &builder, Location loc, + ArrayRef parallelLoopRanges) { + return getGPUProcessorIdsAndCounts( + builder, loc, parallelLoopRanges.size()); + }; + options.distributionMethod.assign( + 3, linalg::DistributionMethod::CyclicNumProcsEqNumIters); + + return options; +} + +class LinalgTileAndDistributePass + : public PassWrapper> { + public: + LinalgTileAndDistributePass(const SPIRVCodegenOptions &options) + : options(options) {} + LinalgTileAndDistributePass(const LinalgTileAndDistributePass &that) + : options(that.options) {} + + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + + void runOnOperation() override { + MLIRContext *context = &getContext(); + IREE::HAL::ExecutableTargetOp targetOp = getOperation(); + ModuleOp module = targetOp.getInnerModule(); + + for (FuncOp funcOp : module.getOps()) { + if (!isEntryPoint(funcOp)) continue; + + SmallVector linalgOps; + SmallVector tiledLoops; + + if (failed(getLinalgOps(funcOp, linalgOps, tiledLoops))) { + return signalPassFailure(); + } + + linalg::Aliases aliases; + linalg::LinalgDependenceGraph dependenceGraph(aliases, linalgOps); + Optional launchConfigOpt = + initGPULaunchConfig(context, dependenceGraph, options, linalgOps); + if (!launchConfigOpt) { + funcOp.emitError("unable to find launch configuration"); + return signalPassFailure(); + } + LaunchConfig &launchConfig = *launchConfigOpt; + + LLVM_DEBUG({ + llvm::dbgs() + << "\n--- IREE Linalg tile and distribute configuration ---\n"; + llvm::dbgs() << "@func " << funcOp.getName() + << ": # workgroup sizes: ["; + interleaveComma(launchConfig.getWorkgroupSize(), llvm::dbgs()); + llvm::dbgs() << "]\n"; + for (auto op : linalgOps) { + llvm::dbgs() << "\t" << op.getOperation()->getName() << " : "; + TileSizesListTypeRef tileSizes = launchConfig.getTileSizes(op); + llvm::dbgs() << "{"; + std::string sep = ""; + for (auto &level : enumerate(tileSizes)) { + llvm::dbgs() << sep << level.index() << " : ["; + sep = ", "; + interleaveComma(level.value(), llvm::dbgs()); + llvm::dbgs() << "]"; + } + llvm::dbgs() << "}\n"; + } + }); + + TileAndFuseOptions tileAndFuseOptions = { + getWorkgroupDistributionOptions(), allocateWorkgroupMemory}; + if (failed(tileAndFuseLinalgBufferOps(funcOp, linalgOps, dependenceGraph, + launchConfig, + tileAndFuseOptions)) || + failed( + updateWorkGroupSize(funcOp, launchConfig.getWorkgroupSize()))) { + return signalPassFailure(); + } + } + } + + private: + SPIRVCodegenOptions options; +}; + +} // namespace + +std::unique_ptr> +createTileAndDistributeAmongWorkgroupsPass(const SPIRVCodegenOptions &options) { + return std::make_unique(options); +} + +static PassRegistration pass( + "iree-codegen-spirv-linalg-tile-and-distribute", + "Tile and distribute Linalg operations on buffers", [] { + SPIRVCodegenOptions options = getSPIRVCodegenOptionsFromClOptions(); + return std::make_unique(options); + }); + +} // namespace iree_compiler +} // namespace mlir diff --git a/iree/compiler/Conversion/LinalgToSPIRV/LinalgTileAndFusePass.cpp b/iree/compiler/Conversion/LinalgToSPIRV/LinalgTileAndFusePass.cpp index 3aae4847324f..36d6ee2262c8 100644 --- a/iree/compiler/Conversion/LinalgToSPIRV/LinalgTileAndFusePass.cpp +++ b/iree/compiler/Conversion/LinalgToSPIRV/LinalgTileAndFusePass.cpp @@ -14,7 +14,8 @@ //===- LinalgTileAndFusePass.cpp - Tile and fuse Linalg on Buffers --------===// // -// Implements a pass to tile and fuse linalg operations on buffers. +// This pass tiles and vectorizes Linalg ops on buffers within in a single +// workgroup. // //===----------------------------------------------------------------------===// @@ -32,6 +33,8 @@ #include "iree/compiler/Dialect/HAL/IR/HALDialect.h" #include "iree/compiler/Dialect/HAL/IR/HALOps.h" #include "iree/compiler/Dialect/Shape/IR/ShapeDialect.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/Support/Debug.h" #include "mlir/Dialect/GPU/GPUDialect.h" #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h" #include "mlir/Dialect/Linalg/IR/LinalgOps.h" @@ -72,33 +75,6 @@ static linalg::LinalgTransformationFilter getLinalgMatchAndReplaceMarker( markers, Identifier::get(replaceMarker, context)); } -/// Returns the distribution options for operations when targeting workgroups. -static linalg::LinalgLoopDistributionOptions getWorkgroupDistributionOptions() { - linalg::LinalgLoopDistributionOptions options; - - options.procInfo = [](OpBuilder &builder, Location loc, - ArrayRef parallelLoopRanges) { - return getGPUProcessorIdsAndCounts( - builder, loc, parallelLoopRanges.size()); - }; - options.distributionMethod.assign( - 3, linalg::DistributionMethod::CyclicNumProcsEqNumIters); - - return options; -} - -/// Applies canonicalization over index calculation inside the given `funcOp`. -static void applyIndexCalculationCanonicalization(FuncOp funcOp) { - MLIRContext *context = funcOp.getContext(); - OwningRewritePatternList canonicalizationPatterns; - DimOp::getCanonicalizationPatterns(canonicalizationPatterns, context); - AddIOp::getCanonicalizationPatterns(canonicalizationPatterns, context); - SubIOp::getCanonicalizationPatterns(canonicalizationPatterns, context); - SignedDivIOp::getCanonicalizationPatterns(canonicalizationPatterns, context); - (void)applyPatternsAndFoldGreedily(funcOp, - std::move(canonicalizationPatterns)); -} - //===----------------------------------------------------------------------===// // Main pass //===----------------------------------------------------------------------===// @@ -429,8 +405,6 @@ void LinalgTileAndFusePass::runOnOperation() { IREE::HAL::ExecutableTargetOp targetOp = getOperation(); ModuleOp module = targetOp.getInnerModule(); - LLVM_DEBUG( - llvm::dbgs() << "--- IREE Linalg tile and fuse configuration ---\n";); for (FuncOp funcOp : module.getOps()) { if (!isEntryPoint(funcOp)) continue; @@ -452,6 +426,7 @@ void LinalgTileAndFusePass::runOnOperation() { LaunchConfig &launchConfig = *launchConfigOpt; LLVM_DEBUG({ + llvm::dbgs() << "\n--- IREE Linalg tile and fuse configuration ---\n"; llvm::dbgs() << "@func " << funcOp.getName() << ": # workgroup sizes: ["; interleaveComma(launchConfig.getWorkgroupSize(), llvm::dbgs()); llvm::dbgs() << "]\n"; @@ -470,55 +445,6 @@ void LinalgTileAndFusePass::runOnOperation() { } }); - if (!options.usingLinalgOnTensors) { - TileAndFuseOptions tileAndFuseOptions = { - getWorkgroupDistributionOptions(), allocateWorkgroupMemory}; - if (failed(tileAndFuseLinalgBufferOps(funcOp, linalgOps, dependenceGraph, - launchConfig, - tileAndFuseOptions)) || - failed( - updateWorkGroupSize(funcOp, launchConfig.getWorkgroupSize()))) { - return signalPassFailure(); - } - } else { - // Find the root operation for the dispatch region and get the tile sizes. - Operation *rootOperation = - launchConfig.getRootOperation(llvm::to_vector<4>(llvm::map_range( - linalgOps, - [](linalg::LinalgOp op) { return op.getOperation(); }))); - if (!rootOperation) { - launchConfig.finalize(funcOp); - return; - } - - ArrayRef rootOperationTileSizes = - launchConfig.getTileSizes(rootOperation, 0); - if (rootOperationTileSizes.empty()) { - launchConfig.finalize(funcOp); - return; - } - - // Only use the tile sizes for parallel loops of the root operation. - rootOperationTileSizes = rootOperationTileSizes.take_front( - getNumOuterParallelLoops(rootOperation)); - - SmallVector workloadPerWorkgroup = - llvm::to_vector<4>(llvm::reverse(rootOperationTileSizes)); - if (failed(materializeStaticLaunchInformation(funcOp, - workloadPerWorkgroup)) || - failed( - updateWorkGroupSize(funcOp, launchConfig.getWorkgroupSize()))) { - funcOp.emitOpError("failed to materialize static launch information"); - return signalPassFailure(); - } - } - - LLVM_DEBUG({ - llvm::dbgs() << "--- After first level of tiling and distribution ---\n"; - funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope()); - llvm::dbgs() << "\n\n"; - }); - if (options.useWorkgroupMemory) { // The promotion patterns are put separate from the tiling patterns to // make sure that the allocated scratchspace memory is constant sizes diff --git a/iree/compiler/Conversion/LinalgToSPIRV/Passes.cpp b/iree/compiler/Conversion/LinalgToSPIRV/Passes.cpp index 0ea3a74f1285..0dfe6cde7228 100644 --- a/iree/compiler/Conversion/LinalgToSPIRV/Passes.cpp +++ b/iree/compiler/Conversion/LinalgToSPIRV/Passes.cpp @@ -83,9 +83,15 @@ static void addLinalgToSPIRVPasses(OpPassManager &pm, // - The Linalg op is kept untouched. // //===--------------------------------------------------------------------===// - if (!options.usingLinalgOnTensors) { + if (options.usingLinalgOnTensors) { + // flow.dispatch.workgroups performed abstract tiling and distribution. Make + // them concrete now since we know the target and settings now. + pm.addPass(createConcretizeTileAmongWorkgroupsPass(options)); + } else { pm.addPass(createSplitDispatchFunctionPass()); + pm.addPass(createTileAndDistributeAmongWorkgroupsPass(options)); } + pm.addPass(createLinalgTileAndFusePass(options)); if (options.vectorizeMemref) { pm.nest().addNestedPass( diff --git a/iree/compiler/Conversion/LinalgToSPIRV/Passes.h b/iree/compiler/Conversion/LinalgToSPIRV/Passes.h index 1c29576e9433..bce9f2efd7c0 100644 --- a/iree/compiler/Conversion/LinalgToSPIRV/Passes.h +++ b/iree/compiler/Conversion/LinalgToSPIRV/Passes.h @@ -81,6 +81,16 @@ createFoldProcessorIDUsesPass(); std::unique_ptr> createMaterializeEntryPointsPass(); +/// Creates a pass to concretize hal.interface.workgroup.* ops with concrete +/// tiling and distribution scheme. +std::unique_ptr> +createConcretizeTileAmongWorkgroupsPass(const SPIRVCodegenOptions &options); + +/// Tiles and distributes Linalg operations on buffers among multiple +/// workgroups. +std::unique_ptr> +createTileAndDistributeAmongWorkgroupsPass(const SPIRVCodegenOptions &options); + //===----------------------------------------------------------------------===// // Pipelines //===----------------------------------------------------------------------===// diff --git a/iree/compiler/Conversion/LinalgToSPIRV/VectorizeMemref.cpp b/iree/compiler/Conversion/LinalgToSPIRV/VectorizeMemref.cpp index 65febab28f3e..671512a8b5fa 100644 --- a/iree/compiler/Conversion/LinalgToSPIRV/VectorizeMemref.cpp +++ b/iree/compiler/Conversion/LinalgToSPIRV/VectorizeMemref.cpp @@ -126,6 +126,7 @@ class MemRefUsageAnalysis { void analyzeFunc(FuncOp funcOp); void analyzeAlloc(AllocOp allocOp); void analyzePlaceholder(IREE::PlaceholderOp placeholderOp); + void analyzeInterfaceBinding(IREE::HAL::InterfaceBindingSubspanOp bindingOp); llvm::DenseMap vectorization_size; llvm::DenseSet transferOps; }; @@ -136,6 +137,8 @@ MemRefUsageAnalysis::MemRefUsageAnalysis(mlir::Operation *op) { if (auto alloc = dyn_cast(op)) analyzeAlloc(alloc); if (auto placeholder = dyn_cast(op)) analyzePlaceholder(placeholder); + if (auto bindingOp = dyn_cast(op)) + analyzeInterfaceBinding(bindingOp); }); } @@ -159,6 +162,15 @@ void MemRefUsageAnalysis::analyzePlaceholder( } } +void MemRefUsageAnalysis::analyzeInterfaceBinding( + IREE::HAL::InterfaceBindingSubspanOp bindingOp) { + SmallVector vectorUses; + if (unsigned vectorSize = isMemRefAndVectorizable(bindingOp, vectorUses)) { + vectorization_size.insert(std::make_pair(bindingOp, vectorSize)); + transferOps.insert(vectorUses.begin(), vectorUses.end()); + } +} + void MemRefUsageAnalysis::analyzeAlloc(AllocOp allocOp) { SmallVector vectorUses; if (unsigned vectorSize = isMemRefAndVectorizable(allocOp, vectorUses)) { @@ -363,6 +375,26 @@ class ProcessPlaceHolder final } }; +class ProcessInterfaceBinding final + : public MemRefConversionPattern { + public: + using MemRefConversionPattern< + IREE::HAL::InterfaceBindingSubspanOp>::MemRefConversionPattern; + + LogicalResult matchAndRewrite( + IREE::HAL::InterfaceBindingSubspanOp bindingOp, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + auto memrefType = bindingOp.getType().dyn_cast(); + if (!memrefType) return failure(); + auto vecMemRef = getVectorizedMemRefType(rewriter, bindingOp.getResult()); + if (!vecMemRef) return failure(); + rewriter.replaceOpWithNewOp( + bindingOp, *vecMemRef, bindingOp.binding(), bindingOp.byte_offset(), + bindingOp.byte_length()); + return success(); + } +}; + class VectorizeMemRefPass final : public PassWrapper> { void runOnOperation() override; @@ -410,8 +442,8 @@ void VectorizeMemRefPass::runOnOperation() { OwningRewritePatternList patterns; patterns.insert(context, - *memrefUsageAnalysis); + ProcessAlloc, ProcessPlaceHolder, ProcessInterfaceBinding>( + context, *memrefUsageAnalysis); ConversionTarget target(*context); target.addDynamicallyLegalOp([&](FuncOp op) { @@ -426,6 +458,10 @@ void VectorizeMemRefPass::runOnOperation() { [&](IREE::PlaceholderOp placeholder) { return !memrefUsageAnalysis->vectorizeMemRef(placeholder); }); + target.addDynamicallyLegalOp( + [&](IREE::HAL::InterfaceBindingSubspanOp bindingOp) { + return !memrefUsageAnalysis->vectorizeMemRef(bindingOp); + }); target.markUnknownOpDynamicallyLegal([&](Operation *op) { if (isa(op)) return !memrefUsageAnalysis->transferConvert(op); diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/batch_matmul_vectorization.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/batch_matmul_vectorization.mlir index 508f72e6ad12..c597271a70c2 100644 --- a/iree/compiler/Conversion/LinalgToSPIRV/test/batch_matmul_vectorization.mlir +++ b/iree/compiler/Conversion/LinalgToSPIRV/test/batch_matmul_vectorization.mlir @@ -1,4 +1,4 @@ -// RUN: iree-opt -split-input-file -pass-pipeline="hal.executable(hal.executable.target(iree-codegen-linalg-tile-and-fuse,canonicalize,cse))" -iree-spirv-enable-vectorization %s | IreeFileCheck %s +// RUN: iree-opt -split-input-file -pass-pipeline="hal.executable(hal.executable.target(iree-codegen-spirv-linalg-tile-and-distribute,iree-codegen-linalg-tile-and-fuse,canonicalize,cse))" -iree-spirv-enable-vectorization %s | IreeFileCheck %s hal.executable @batch_matmul_static_shape attributes {sym_visibility = "private"} { hal.interface @legacy_io { diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/concretize_tile_among_workgroups.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/concretize_tile_among_workgroups.mlir new file mode 100644 index 000000000000..d863939ede90 --- /dev/null +++ b/iree/compiler/Conversion/LinalgToSPIRV/test/concretize_tile_among_workgroups.mlir @@ -0,0 +1,225 @@ +// RUN: iree-opt -split-input-file -pass-pipeline="hal.executable(hal.executable.target(iree-spirv-concretize-tile-among-workgroups))" -iree-spirv-tile-size=16,4,4 -iree-spirv-workgroup-size=4,4,1 %s | IreeFileCheck %s + +hal.executable @conv2d_static_shape attributes {sym_visibility = "private"} { + hal.interface @legacy_io { + hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read" + hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read" + hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard" + } + hal.executable.target @vulkan_spirv, filter="vulkan*" { + hal.executable.entry_point @conv2d_static_shape attributes { + interface = @legacy_io, ordinal = 0 : i32, + signature = (!flow.dispatch.input<1x225x225x16xf32>, !flow.dispatch.input<3x3x16x32xf32>, !flow.dispatch.output<1x112x112x32xf32>) -> ()} + module attributes {spv.target_env = #spv.target_env<#spv.vce, ARM:IntegratedGPU, {}>} { + func @conv2d_static_shape() { + %cst = constant 0.000000e+00 : f32 + %c32 = constant 32 : index + %c112 = constant 112 : index + %c0 = constant 0 : index + %0 = hal.interface.binding.subspan @legacy_io::@arg0[%c0] : memref<1x225x225x16xf32> + %1 = hal.interface.binding.subspan @legacy_io::@arg1[%c0] : memref<3x3x16x32xf32> + %2 = hal.interface.binding.subspan @legacy_io::@ret0[%c0] : memref<1x112x112x32xf32> + %workgroup_size_x = hal.interface.workgroup.size[0] : index + %workgroup_size_y = hal.interface.workgroup.size[1] : index + %workgroup_size_z = hal.interface.workgroup.size[2] : index + %workgroup_id_x = hal.interface.workgroup.id[0] : index + %workgroup_count_x = hal.interface.workgroup.count[0] : index + %workgroup_id_y = hal.interface.workgroup.id[1] : index + %workgroup_count_y = hal.interface.workgroup.count[1] : index + %workgroup_id_z = hal.interface.workgroup.id[2] : index + %workgroup_count_z = hal.interface.workgroup.count[2] : index + %3 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_z, %workgroup_size_z] + %4 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_z, %workgroup_size_z] + scf.for %arg0 = %3 to %c112 step %4 { + %5 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_y, %workgroup_size_y] + %6 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_y, %workgroup_size_y] + scf.for %arg1 = %5 to %c112 step %6 { + %7 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_x, %workgroup_size_x] + %8 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_x, %workgroup_size_x] + scf.for %arg2 = %7 to %c32 step %8 { + %9 = affine.apply affine_map<(d0) -> (d0 * 2)>(%arg0) + %10 = affine.min affine_map<(d0)[s0] -> (s0 * 2 + 1, d0 * -2 + 225)>(%arg0)[%workgroup_size_z] + %11 = affine.apply affine_map<(d0) -> (d0 * 2)>(%arg1) + %12 = affine.min affine_map<(d0)[s0] -> (s0 * 2 + 1, d0 * -2 + 225)>(%arg1)[%workgroup_size_y] + %13 = subview %0[0, %9, %11, 0] [1, %10, %12, 16] [1, 1, 1, 1] : memref<1x225x225x16xf32> to memref<1x?x?x16xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 810000 + s0 + d1 * 3600 + d2 * 16 + d3)>> + %14 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 32)>(%arg2)[%workgroup_size_x] + %15 = subview %1[0, 0, 0, %arg2] [3, 3, 16, %14] [1, 1, 1, 1] : memref<3x3x16x32xf32> to memref<3x3x16x?xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 1536 + s0 + d1 * 512 + d2 * 32 + d3)>> + %16 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 112)>(%arg0)[%workgroup_size_z] + %17 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 112)>(%arg1)[%workgroup_size_y] + %18 = subview %2[0, %arg0, %arg1, %arg2] [1, %16, %17, %14] [1, 1, 1, 1] : memref<1x112x112x32xf32> to memref<1x?x?x?xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 401408 + s0 + d1 * 3584 + d2 * 32 + d3)>> + linalg.fill(%18, %cst) : memref<1x?x?x?xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 401408 + s0 + d1 * 3584 + d2 * 32 + d3)>>, f32 + linalg.conv_2d_input_nhwc_filter_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%13, %15 : memref<1x?x?x16xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 810000 + s0 + d1 * 3600 + d2 * 16 + d3)>>, memref<3x3x16x?xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 1536 + s0 + d1 * 512 + d2 * 32 + d3)>>) outs(%18 : memref<1x?x?x?xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 401408 + s0 + d1 * 3584 + d2 * 32 + d3)>>) + } + } + } + return + } + hal.interface @legacy_io attributes {sym_visibility = "private"} { + hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read" + hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read" + hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard" + } + } + } +} + +// Check that for a fully static shaped dispatch region, we can: +// 1) Generate static constant workgroup counts, +// 2) Replace hal.interface.workgroup.{size|count} ops with constants, +// 3) Canonicalize loops and subview ops. + +// CHECK: #[[MULMAP:.+]] = affine_map<()[s0, s1] -> (s0 * s1)> +// CHECK: #[[MAP0:.+]] = affine_map<(d0) -> (d0 * 2)> +// CHECK: #[[MAP1:.+]] = affine_map<(d0)[s0] -> (9, d0 * -2 + 225)> +// CHECK: #[[MAP2:.+]] = affine_map<(d0)[s0] -> (16, -d0 + 32)> +// CHECK: #[[MAP3:.+]] = affine_map<(d0)[s0] -> (4, -d0 + 112)> + +// CHECK: hal.executable.entry_point @conv2d_static_shape +// CHECK: %[[C2:.+]] = constant 2 : index +// CHECK: %[[C28_0:.+]] = constant 28 : index +// CHECK: %[[C28_1:.+]] = constant 28 : index +// CHECK: hal.return %[[C2]], %[[C28_0]], %[[C28_1]] : index, index, index + +// CHECK: func @conv2d_static_shape() +// CHECK-SAME: spv.entry_point_abi = {local_size = dense<[4, 4, 1]> : vector<3xi32>} + +// CHECK: %[[ID_X:.+]] = hal.interface.workgroup.id[0] : index +// CHECK: %[[ID_Y:.+]] = hal.interface.workgroup.id[1] : index +// CHECK: %[[ID_Z:.+]] = hal.interface.workgroup.id[2] : index + +// CHECK: %[[Z_MUL_4:.+]] = affine.apply #[[MULMAP]]()[%[[ID_Z]], %c4] +// CHECK: %[[Y_MUL_4:.+]] = affine.apply #[[MULMAP]]()[%[[ID_Y]], %c4] +// CHECK: %[[X_MUL_16:.+]] = affine.apply #[[MULMAP]]()[%[[ID_X]], %c16] + +// CHECK: %[[Z_OFFSET:.+]] = affine.apply #[[MAP0]](%[[Z_MUL_4]]) +// CHECK: %[[Z_SIZE:.+]] = affine.min #[[MAP1]](%[[Z_MUL_4]])[%c4] +// CHECK: %[[Y_OFFSET:.+]] = affine.apply #[[MAP0]](%[[Y_MUL_4]]) +// CHECK: %[[Y_SIZE:.+]] = affine.min #[[MAP1]](%[[Y_MUL_4]])[%c4] + +// CHECK: %[[INPUT:.+]] = subview %{{.+}}[0, %[[Z_OFFSET]], %[[Y_OFFSET]], 0] [1, %[[Z_SIZE]], %[[Y_SIZE]], 16] [1, 1, 1, 1] : memref<1x225x225x16xf32> to memref<1x?x?x16xf32, {{.+}}> + +// CHECK: %[[X_SIZE:.+]] = affine.min #[[MAP2]](%[[X_MUL_16]])[%c16] + +// CHECK: %[[FILTER:.+]] = subview %{{.+}}[0, 0, 0, %[[X_MUL_16]]] [3, 3, 16, %[[X_SIZE]]] [1, 1, 1, 1] : memref<3x3x16x32xf32> to memref<3x3x16x?xf32, {{.+}}> + +// CHECK: %[[Z_SIZE:.+]] = affine.min #[[MAP3]](%[[Z_MUL_4]])[%c4] +// CHECK: %[[Y_SIZE:.+]] = affine.min #[[MAP3]](%[[Y_MUL_4]])[%c4] +// CHECK: %[[OUTPUT:.+]] = subview %{{.+}}[0, %[[Z_MUL_4]], %[[Y_MUL_4]], %[[X_MUL_16]]] [1, %[[Z_SIZE]], %[[Y_SIZE]], %[[X_SIZE]]] [1, 1, 1, 1] : memref<1x112x112x32xf32> to memref<1x?x?x?xf32, {{.+}}> + +// CHECK: linalg.fill(%[[OUTPUT]], %{{.+}}) +// CHECK: linalg.conv_2d_input_nhwc_filter_hwcf {dilations = dense<1> : tensor<2xi64>, is_root_op, strides = dense<2> : tensor<2xi64>} ins(%[[INPUT]], %[[FILTER]] : memref<1x?x?x16xf32, {{.+}}>, memref<3x3x16x?xf32, {{.+}}>) outs(%[[OUTPUT]] : memref<1x?x?x?xf32, {{.+}}>) + +// ----- + +hal.executable @matmul_dynamic_shape attributes {sym_visibility = "private"} { + hal.interface @legacy_io { + hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read" + hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read" + hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard" + } + hal.executable.target @vulkan_spirv, filter="vulkan*" { + hal.executable.entry_point @matmul_dynamic_shape attributes { + interface = @legacy_io, ordinal = 0 : i32, + signature = (!flow.dispatch.input<1x225x225x16xf32>, !flow.dispatch.input<3x3x16x32xf32>, !flow.dispatch.output<1x112x112x32xf32>) -> ()} + module attributes {spv.target_env = #spv.target_env<#spv.vce, ARM:IntegratedGPU, {}>} { + func @matmul_dynamic_shape() { + %cst = constant 0.000000e+00 : f32 + %c0 = constant 0 : index + %0 = hal.interface.load.constant offset = 0 : index + %1 = hal.interface.load.constant offset = 1 : index + %2 = hal.interface.binding.subspan @legacy_io::@arg0[%c0] : memref + %3 = hal.interface.binding.subspan @legacy_io::@arg1[%c0] : memref + %4 = hal.interface.binding.subspan @legacy_io::@ret0[%c0] : memref + %5 = hal.interface.load.constant offset = 2 : index + %6 = hal.interface.load.constant offset = 3 : index + %7 = hal.interface.load.constant offset = 4 : index + %8 = hal.interface.load.constant offset = 5 : index + %9 = hal.interface.load.constant offset = 6 : index + %10 = hal.interface.load.constant offset = 7 : index + %11 = shapex.make_ranked_shape %5, %6 : (index, index) -> !shapex.ranked_shape<[?,?]> + %12 = shapex.tie_shape %2, %11 : memref, !shapex.ranked_shape<[?,?]> + %13 = shapex.make_ranked_shape %7, %8 : (index, index) -> !shapex.ranked_shape<[?,?]> + %14 = shapex.tie_shape %3, %13 : memref, !shapex.ranked_shape<[?,?]> + %15 = shapex.make_ranked_shape %9, %10 : (index, index) -> !shapex.ranked_shape<[?,?]> + %16 = shapex.tie_shape %4, %15 : memref, !shapex.ranked_shape<[?,?]> + %workgroup_size_x = hal.interface.workgroup.size[0] : index + %workgroup_size_y = hal.interface.workgroup.size[1] : index + %workgroup_id_x = hal.interface.workgroup.id[0] : index + %workgroup_count_x = hal.interface.workgroup.count[0] : index + %workgroup_id_y = hal.interface.workgroup.id[1] : index + %workgroup_count_y = hal.interface.workgroup.count[1] : index + %17 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_y, %workgroup_size_y] + %18 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_y, %workgroup_size_y] + scf.for %arg0 = %17 to %5 step %18 { + %19 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_x, %workgroup_size_x] + %20 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_x, %workgroup_size_x] + scf.for %arg1 = %19 to %8 step %20 { + %21 = affine.min affine_map<(d0)[s0, s1] -> (s1, -d0 + s0)>(%arg0)[%5, %workgroup_size_y] + %22 = subview %12[%arg0, 0] [%21, %6] [1, 1] : memref to memref (d0 * s1 + s0 + d1)>> + %23 = affine.min affine_map<(d0)[s0, s1] -> (s1, -d0 + s0)>(%arg1)[%8, %workgroup_size_x] + %24 = subview %14[0, %arg1] [%7, %23] [1, 1] : memref to memref (d0 * s1 + s0 + d1)>> + %25 = affine.min affine_map<(d0)[s0, s1] -> (s1, -d0 + s0)>(%arg0)[%0, %workgroup_size_y] + %26 = affine.min affine_map<(d0)[s0, s1] -> (s1, -d0 + s0)>(%arg1)[%1, %workgroup_size_x] + %27 = subview %16[%arg0, %arg1] [%25, %26] [1, 1] : memref to memref (d0 * s1 + s0 + d1)>> + linalg.fill(%27, %cst) {__internal_linalg_transform__ = "workgroup"} : memref (d0 * s1 + s0 + d1)>>, f32 + linalg.matmul {__internal_linalg_transform__ = "workgroup"} ins(%22, %24 : memref (d0 * s1 + s0 + d1)>>, memref (d0 * s1 + s0 + d1)>>) outs(%27 : memref (d0 * s1 + s0 + d1)>>) + } + } + return + } + hal.interface @legacy_io attributes {sym_visibility = "private"} { + hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read" + hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read" + hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard" + } + } + } +} + +// Check that for a fully dynamic shaped dispatch region, we can: +// 1) Generate symbolic workgroup counts, +// 2) Replace hal.interface.workgroup.size (but not .count) ops with constants. + +// CHECK: #[[DIV16MAP:.+]] = affine_map<()[s0] -> (s0 ceildiv 16)> +// CHECK: #[[DIV4MAP:.+]] = affine_map<()[s0] -> (s0 ceildiv 4)> +// CHECK: #[[MULMAP:.+]] = affine_map<()[s0, s1] -> (s0 * s1)> +// CHECK: #[[YBOUNDMAP:.+]] = affine_map<(d0)[s0, s1] -> (4, -d0 + s0)> +// CHECK: #[[XBOUNDMAP:.+]] = affine_map<(d0)[s0, s1] -> (16, -d0 + s0)> + +// CHECK: hal.executable.entry_point @matmul_dynamic_shape +// CHECK: ^{{.+}}(%[[BBARG0:.+]]: index, %[[BBARG1:.+]]: index, %{{.+}}: index): +// CHECK: %c1 = constant 1 : index +// CHECK: %[[SIZE0:.+]] = affine.apply #[[DIV16MAP]]()[%[[BBARG0]]] +// CHECK: %[[SIZE1:.+]] = affine.apply #[[DIV4MAP]]()[%[[BBARG1]]] +// CHECK: hal.return %[[SIZE0]], %[[SIZE1]], %c1 + +// CHECK: func @matmul_dynamic_shape() +// CHECK-SAME: spv.entry_point_abi = {local_size = dense<[4, 4, 1]> : vector<3xi32>} + +// CHECK: %[[C_DIM0:.+]] = hal.interface.load.constant offset = 0 : index +// CHECK: %[[C_DIM1:.+]] = hal.interface.load.constant offset = 1 : index +// CHECK: %[[A_DIM0:.+]] = hal.interface.load.constant offset = 2 : index +// CHECK: %[[A_DIM1:.+]] = hal.interface.load.constant offset = 3 : index +// CHECK: %[[B_DIM0:.+]] = hal.interface.load.constant offset = 4 : index +// CHECK: %[[B_DIM1:.+]] = hal.interface.load.constant offset = 5 : index + +// CHECK: %[[ID_X:.+]] = hal.interface.workgroup.id[0] : index +// CHECK: %[[COUNT_X:.+]] = hal.interface.workgroup.count[0] : index +// CHECK: %[[ID_Y:.+]] = hal.interface.workgroup.id[1] : index +// CHECK: %[[COUNT_Y:.+]] = hal.interface.workgroup.count[1] : index + +// CHECK: %[[Y_LB:.+]] = affine.apply #[[MULMAP]]()[%[[ID_Y]], %c4] +// CHECK: %[[Y_STEP:.+]] = affine.apply #[[MULMAP]]()[%[[COUNT_Y]], %c4] +// CHECK: scf.for %[[IV_Y:.+]] = %[[Y_LB]] to %[[A_DIM0]] step %[[Y_STEP]] +// CHECK: %[[X_LB:.+]] = affine.apply #[[MULMAP]]()[%[[ID_X]], %c16] +// CHECK: %[[X_STEP:.+]] = affine.apply #[[MULMAP]]()[%[[COUNT_X]], %c16] +// CHECK: scf.for %[[IV_X:.+]] = %[[X_LB]] to %[[B_DIM1]] step %[[X_STEP]] +// CHECK: %[[Y_SIZE:.+]] = affine.min #[[YBOUNDMAP]](%[[IV_Y]])[%[[A_DIM0]], %c4] +// CHECK: %[[A_TILE:.+]] = subview %{{.+}}[%[[IV_Y]], 0] [%[[Y_SIZE]], %[[A_DIM1]]] [1, 1] : memref to memref +// CHECK: %[[X_SIZE:.+]] = affine.min #[[XBOUNDMAP]](%[[IV_X]])[%[[B_DIM1]], %c16] +// CHECK: %[[B_TILE:.+]] = subview %{{.+}}[0, %[[IV_X]]] [%[[B_DIM0]], %[[X_SIZE]]] [1, 1] : memref to memref +// CHECK: %[[Y_SIZE:.+]] = affine.min #[[YBOUNDMAP]](%[[IV_Y]])[%[[C_DIM0]], %c4] +// CHECK: %[[X_SIZE:.+]] = affine.min #[[XBOUNDMAP]](%[[IV_X]])[%[[C_DIM1]], %c16] +// CHECK: %[[C_TILE:.+]] = subview %{{.+}}[%[[IV_Y]], %[[IV_X]]] [%[[Y_SIZE]], %[[X_SIZE]]] [1, 1] : memref to memref +// CHECK: linalg.fill(%[[C_TILE]], %cst) {__internal_linalg_transform__ = "workgroup"} : memref, f32 +// CHECK: linalg.matmul {__internal_linalg_transform__ = "workgroup", is_root_op} ins(%[[A_TILE]], %[[B_TILE]] : memref, memref) outs(%[[C_TILE]] : memref) diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/fold-gpu-procid-uses.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/fold-gpu-procid-uses.mlir index fe9bbabefc44..8adf600f5a96 100644 --- a/iree/compiler/Conversion/LinalgToSPIRV/test/fold-gpu-procid-uses.mlir +++ b/iree/compiler/Conversion/LinalgToSPIRV/test/fold-gpu-procid-uses.mlir @@ -34,6 +34,40 @@ hal.executable @fold_block_id attributes {sym_visibility = "private"} { // ----- +hal.executable @fold_interface_workgroup_id attributes {sym_visibility = "private"} { + hal.interface @legacy_io { + } + hal.executable.target @vulkan, filter="vulkan*" { + hal.executable.entry_point @fold_interface_workgroup_id attributes { + interface = @legacy_io, ordinal = 0 : i32, + signature = () -> ()} { + ^bb0(%arg0 : index, %arg1 : index, %arg2 : index): + %x = constant 112: index + %y = constant 42: index + %z = constant 1: index + hal.return %x, %y, %z: index, index, index + } + module { + func @fold_interface_workgroup_id() -> (index, index, index) { + %0 = hal.interface.workgroup.id[0] : index + %1 = hal.interface.workgroup.id[1] : index + %2 = hal.interface.workgroup.id[2] : index + %3 = affine.min affine_map<()[s0] -> (3, s0 * -2 + 225)>()[%0] + %4 = affine.min affine_map<()[s0] -> (8, s0 * -1 + s0 * -1 + s0 * -1 + 131)>()[%2] + %5 = affine.min affine_map<()[s0] -> (11, s0 + 15)>()[%3] + return %3, %4, %5: index, index, index + } + } + } +} +// CHECK-LABEL: func @fold_interface_workgroup_id() +// CHECK-DAG: %[[C3:.+]] = constant 3 +// CHECK-DAG: %[[C8:.+]] = constant 8 +// CHECK-DAG: %[[C11:.+]] = constant 11 +// CHECK-DAG: return %[[C3]], %[[C8]], %[[C11]] + +// ----- + hal.executable @fold_thread_id attributes {sym_visibility = "private"} { hal.interface @legacy_io { } diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/linalg_tile_and_fuse.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/linalg_tile_and_fuse.mlir index d8da2427ec4b..f5a2b68858ba 100644 --- a/iree/compiler/Conversion/LinalgToSPIRV/test/linalg_tile_and_fuse.mlir +++ b/iree/compiler/Conversion/LinalgToSPIRV/test/linalg_tile_and_fuse.mlir @@ -1,4 +1,4 @@ -// RUN: iree-opt -split-input-file -pass-pipeline="hal.executable(hal.executable.target(iree-codegen-linalg-tile-and-fuse))" -iree-spirv-enable-vectorization -canonicalize -cse %s | IreeFileCheck %s +// RUN: iree-opt -split-input-file -pass-pipeline="hal.executable(hal.executable.target(iree-codegen-spirv-linalg-tile-and-distribute,iree-codegen-linalg-tile-and-fuse))" -iree-spirv-enable-vectorization -canonicalize -cse %s | IreeFileCheck %s // TODO(GH-4901): Convert these tests back to use dynamic shapes when linalg on tensors becomes default. hal.executable @conv_no_padding attributes {sym_visibility = "private"} { @@ -439,148 +439,3 @@ hal.executable @three_op_fusion attributes {sym_visibility = "private"} { // CHECK-SAME: ) // CHECK-SAME: outs(%[[SV_RET0]] // CHECK-SAME: ) - -// ----- - -// TODO(GH-4901): Convert these tests back to use dynamic shapes when linalg on tensors becomes default. -hal.executable @conv_tiled_and_vectorized attributes {sym_visibility = "private"} { - hal.interface @legacy_io { - hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read" - hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read" - hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard" - } - hal.executable.target @vulkan, filter="dylib*" { - hal.executable.entry_point @conv_tiled_and_vectorized attributes { - interface = @legacy_io, ordinal = 0 : i32, - signature = (!flow.dispatch.input<1x225x225x16xf32>, !flow.dispatch.input<3x3x16x32xf32>, - !flow.dispatch.output<1x112x112x32xf32>) -> ()} - module attributes { - spv.target_env = #spv.target_env<#spv.vce, ARM:IntegratedGPU, {max_compute_shared_memory_size = 32768 : i32, max_compute_workgroup_invocations = 512 : i32, max_compute_workgroup_size = dense<512> : vector<3xi32>, subgroup_size = 16 : i32}> - } { - func @conv_tiled_and_vectorized() { - %cst = constant 0.000000e+00 : f32 - %0 = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0} : memref<1x112x112x32xf32> - %1 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<1x225x225x16xf32> - %2 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg1} : memref<3x3x16x32xf32> - linalg.fill(%0, %cst) : memref<1x112x112x32xf32>, f32 - linalg.conv_2d_input_nhwc_filter_hwcf {dilations = dense<1> : vector<2xi64>, strides = dense<2> : vector<2xi64>} - ins (%1, %2: memref<1x225x225x16xf32>, memref<3x3x16x32xf32>) - outs (%0: memref<1x112x112x32xf32>) - return - } - - hal.interface @legacy_io attributes {sym_visibility = "private"} { - hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read" - hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read" - hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard" - } - } - } -} -// CHECK-LABEL: func @conv_tiled_and_vectorized() - -// For linalg.fill -// CHECK-COUNT-4: vector.transfer_write - -// For linalg.conv_2d_input_nhwc_filter_hwcf -// CHECK-COUNT-4: vector.transfer_read - -// check tiling loop along filter height/width and input channel -// CHECK: scf.for %{{.*}} = %c0 to %c3 step %c1 -// CHECK-SAME: -> (vector<1x4xf32>, vector<1x4xf32>, vector<1x4xf32>, vector<1x4xf32>) -// CHECK: scf.for %{{.*}} = %c0 to %c3 step %c1 -// CHECK-SAME: -> (vector<1x4xf32>, vector<1x4xf32>, vector<1x4xf32>, vector<1x4xf32>) -// CHECK: scf.for %{{.*}} = %c0 to %c16 step %c4 -// CHECK-SAME: -> (vector<1x4xf32>, vector<1x4xf32>, vector<1x4xf32>, vector<1x4xf32>) - -// CHECK-COUNT-16: vector.contract - -// CHECK-COUNT-3: scf.yield - -// For linalg.conv_2d_input_nhwc_filter_hwcf -// CHECK-COUNT-4: vector.transfer_write - -// ----- - -hal.executable @depthwise_conv2d_2452x2423_valid_stride_2 attributes {sym_visibility = "private"} { - hal.interface @legacy_io { - hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read" - hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read" - hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard" - } - hal.executable.target @vulkan_spirv, filter="vulkan*" { - hal.executable.entry_point @depthwise_conv2d_2452x2423_valid_stride_2 attributes {interface = @legacy_io, ordinal = 0 : i32, signature = (tensor<2x4x5x2xf32>, tensor<2x4x2x3xf32>) -> tensor<2x2x1x6xf32>} - module attributes {spv.target_env = #spv.target_env<#spv.vce, SwiftShader:CPU, {cooperative_matrix_properties_nv = [], max_compute_shared_memory_size = 16384 : i32, max_compute_workgroup_invocations = 128 : i32, max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>, subgroup_size = 4 : i32}>} { - func @depthwise_conv2d_2452x2423_valid_stride_2() { - %cst = constant 0.000000e+00 : f32 - %0 = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0} : memref<2x2x1x2x3xf32> - %1 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<2x4x5x2xf32> - %2 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg1} : memref<2x4x2x3xf32> - linalg.fill(%0, %cst) : memref<2x2x1x2x3xf32>, f32 - linalg.depthwise_conv_2d_input_nhwc_filter_hwcf {strides = dense<2> : tensor<2xi64>} ins(%1, %2 : memref<2x4x5x2xf32>, memref<2x4x2x3xf32>) outs(%0 : memref<2x2x1x2x3xf32>) - return - } - hal.interface @legacy_io attributes {sym_visibility = "private"} { - hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read" - hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read" - hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard" - } - } - } -} - -// CHECK-LABEL: func @depthwise_conv2d_2452x2423_valid_stride_2() -// CHECK: linalg.fill -// CHECK: linalg.generic -// CHECK-NOT: linalg.depthwise_conv_2d_input_nhwc_filter_hwcf - -// ----- - -hal.executable @conv_tiled_and_vectorized attributes {sym_visibility = "private"} { - hal.interface @legacy_io { - hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read" - hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read" - hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard" - } - hal.executable.target @vulkan, filter="dylib*" { - hal.executable.entry_point @depthwise_conv_tiled_and_vectorized attributes { - interface = @legacy_io, ordinal = 0 : i32, - signature = (!flow.dispatch.input<1x113x113x96xf32>, !flow.dispatch.input<3x3x96xf32>, - !flow.dispatch.output<1x56x56x96xf32>) -> ()} - module attributes { - spv.target_env = #spv.target_env<#spv.vce, ARM:IntegratedGPU, {max_compute_shared_memory_size = 32768 : i32, max_compute_workgroup_invocations = 512 : i32, max_compute_workgroup_size = dense<512> : vector<3xi32>, subgroup_size = 16 : i32}> - } { - func @depthwise_conv_tiled_and_vectorized() { - %cst = constant 0.000000e+00 : f32 - %0 = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0} : memref<1x56x56x96xf32> - %1 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<1x113x113x96xf32> - %2 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg1} : memref<3x3x96xf32> - linalg.fill(%0, %cst) : memref<1x56x56x96xf32>, f32 - linalg.depthwise_conv_2d_input_nhwc_filter_hwc {strides = dense<2> : tensor<2xi64>} ins(%1, %2 : memref<1x113x113x96xf32>, memref<3x3x96xf32>) outs(%0 : memref<1x56x56x96xf32>) - return - } - } - } -} - -// CHECK-LABEL: func @depthwise_conv_tiled_and_vectorized() - -// For linalg.fill -// CHECK: vector.transfer_write - -// For linalg.depthwise_conv_2d_input_nhwc_filter_hwc -// CHECK: vector.transfer_read - -// check tiling loop along filter height/width and input channel -// CHECK: scf.for %{{.+}} = %c0 to %c3 step %c1 -// CHECK-SAME: -> (vector<4xf32>) -// CHECK: scf.for %{{.+}} = %c0 to %c3 step %c1 -// CHECK-SAME: -> (vector<4xf32>) - - -// CHECK: vector.fma - -// CHECK-COUNT-2: scf.yield - -// For linalg.depthwise_conv_2d_input_nhwc_filter_hwc -// CHECK: vector.transfer_write diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/materialize_launch_configuration.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/materialize_launch_configuration.mlir index 8151a358e43f..c55b112f32fa 100644 --- a/iree/compiler/Conversion/LinalgToSPIRV/test/materialize_launch_configuration.mlir +++ b/iree/compiler/Conversion/LinalgToSPIRV/test/materialize_launch_configuration.mlir @@ -1,4 +1,4 @@ -// RUN: iree-opt -pass-pipeline="hal.executable(hal.executable.target(iree-codegen-linalg-tile-and-fuse))" -iree-codegen-spirv-experimental-linalg-on-tensors -cse -canonicalize -split-input-file %s | IreeFileCheck %s +// RUN: iree-opt -pass-pipeline="hal.executable(hal.executable.target(iree-spirv-concretize-tile-among-workgroups))" -iree-codegen-spirv-experimental-linalg-on-tensors -cse -canonicalize -split-input-file %s | IreeFileCheck %s hal.executable @matmul_tensors attributes {sym_visibility = "private"} { hal.interface @legacy_io { diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/matmul_fused_vectorization.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/matmul_fused_vectorization.mlir index 88251f44d56a..b07f50ad08ff 100644 --- a/iree/compiler/Conversion/LinalgToSPIRV/test/matmul_fused_vectorization.mlir +++ b/iree/compiler/Conversion/LinalgToSPIRV/test/matmul_fused_vectorization.mlir @@ -1,4 +1,4 @@ -// RUN: iree-opt -split-input-file -pass-pipeline="hal.executable(hal.executable.target(iree-codegen-linalg-tile-and-fuse,canonicalize,cse))" -iree-spirv-enable-vectorization %s | IreeFileCheck %s +// RUN: iree-opt -split-input-file -pass-pipeline="hal.executable(hal.executable.target(iree-codegen-spirv-linalg-tile-and-distribute,iree-codegen-linalg-tile-and-fuse,canonicalize,cse))" -iree-spirv-enable-vectorization %s | IreeFileCheck %s hal.executable @matmul_static_shape attributes {sym_visibility = "private"} { hal.interface @legacy_io { @@ -61,67 +61,3 @@ hal.executable @matmul_static_shape attributes {sym_visibility = "private"} { // CHECK: scf.yield // CHECK-COUNT-8: vector.transfer_write %[[FOR_RES]] // CHECK: return - -// ----- - -hal.executable @matmul_static_shape_f16 attributes {sym_visibility = "private"} { - hal.interface @legacy_io { - hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read" - hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read" - hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard" - } - hal.executable.target @vulkan, filter="dylib*" { - hal.executable.entry_point @matmul_static_shape_f16 attributes { - interface = @legacy_io, ordinal = 0 : i32, - signature = (!flow.dispatch.input, !flow.dispatch.input, - !flow.dispatch.output) -> ()} - module attributes { - spv.target_env = - #spv.target_env<#spv.vce, - ARM:IntegratedGPU, - {max_compute_shared_memory_size = 32768 : i32, - max_compute_workgroup_invocations = 512 : i32, - max_compute_workgroup_size = dense<512> : vector<3xi32>, - subgroup_size = 16 : i32}>} { - func @matmul_static_shape_f16() - attributes {vkspv.num_workgroups_fn = @matmul_static_shape__num_workgroups__} { - %arg0 = iree.placeholder for "interface buffer" - {binding = @legacy_io::@arg0, operand_result_num = 0 : i32} : memref<4096x4096xf16> - %arg1 = iree.placeholder for "interface buffer" - {binding = @legacy_io::@arg1, operand_result_num = 1 : i32} : memref<4096x4096xf16> - %ret0 = iree.placeholder for "interface buffer" - {binding = @legacy_io::@ret0, operand_result_num = 2 : i32} : memref<4096x4096xf16> - %cst = constant 0.000000e+00 : f16 - linalg.fill(%ret0, %cst) : memref<4096x4096xf16>, f16 - linalg.matmul ins(%arg0, %arg1 : memref<4096x4096xf16>, memref<4096x4096xf16>) - outs(%ret0 : memref<4096x4096xf16>) - return - } - func private @matmul_static_shape__num_workgroups__ - (!shapex.ranked_shape<[4096, 4096]>, !shapex.ranked_shape<[4096, 4096]>, - !shapex.ranked_shape<[4096, 4096]>) -> (index, index, index) - hal.interface @legacy_io attributes {sym_visibility = "private"} { - hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read" - hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read" - hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard" - } - } - } -} -// CHECK-LABEL: func @matmul_static_shape_f16 -// CHECK-COUNT-16: vector.transfer_write -// CHECK-COUNT-16: vector.transfer_read -// CHECK: %[[FOR_RES:.+]]:16 = scf.for -// CHECK-COUNT-16: vector.transfer_read -// CHECK-COUNT-64: vector.contract -// CHECK: scf.yield -// CHECK-COUNT-16: vector.transfer_write %[[FOR_RES]] -// CHECK: return diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/matmul_vectorization.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/matmul_vectorization.mlir index 625b2b5ee41c..ee9eef28624c 100644 --- a/iree/compiler/Conversion/LinalgToSPIRV/test/matmul_vectorization.mlir +++ b/iree/compiler/Conversion/LinalgToSPIRV/test/matmul_vectorization.mlir @@ -1,5 +1,5 @@ -// RUN: iree-opt -split-input-file -pass-pipeline="hal.executable(hal.executable.target(iree-codegen-linalg-tile-and-fuse,canonicalize,cse))" -iree-spirv-enable-vectorization %s | IreeFileCheck %s -// RUN: iree-opt -split-input-file -pass-pipeline="hal.executable(hal.executable.target(iree-codegen-linalg-tile-and-fuse,canonicalize,cse))" -iree-spirv-enable-vectorization -iree-spirv-use-workgroup-memory %s | IreeFileCheck %s -check-prefix=PROMOTE +// RUN: iree-opt -split-input-file -pass-pipeline="hal.executable(hal.executable.target(iree-codegen-spirv-linalg-tile-and-distribute,iree-codegen-linalg-tile-and-fuse,canonicalize,cse))" -iree-spirv-enable-vectorization %s | IreeFileCheck %s +// RUN: iree-opt -split-input-file -pass-pipeline="hal.executable(hal.executable.target(iree-codegen-spirv-linalg-tile-and-distribute,iree-codegen-linalg-tile-and-fuse,canonicalize,cse))" -iree-spirv-enable-vectorization -iree-spirv-use-workgroup-memory %s | IreeFileCheck %s -check-prefix=PROMOTE hal.executable @matmul_static_shape attributes {sym_visibility = "private"} { hal.interface @legacy_io { diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/memref_vecrotization.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/memref_vecrotization.mlir index f5a6a5410268..0af82c633e63 100644 --- a/iree/compiler/Conversion/LinalgToSPIRV/test/memref_vecrotization.mlir +++ b/iree/compiler/Conversion/LinalgToSPIRV/test/memref_vecrotization.mlir @@ -132,3 +132,24 @@ hal.interface @legacy_io attributes {sym_visibility = "private"} { hal.interface.binding @arg0, set=1, binding=2, type="StorageBuffer", access="Read" hal.interface.binding @ret0, set=3, binding=4, type="StorageBuffer", access="Write" } + +// ----- + +func @vectorize_binding_subspan() { + %cst = constant 0.000000e+00 : f32 + %c0 = constant 0 : index + // CHECK: hal.interface.binding.subspan @legacy_io::@arg0[%c0] + // CHECK-SAME: memref<4096x1024xvector<4xf32>> + // CHECK: hal.interface.binding.subspan @legacy_io::@ret0[%c0] + // CHECK-SAME: memref<4096x1024xvector<4xf32>> + %0 = hal.interface.binding.subspan @legacy_io::@arg0[%c0] : memref<4096x4096xf32> + %1 = hal.interface.binding.subspan @legacy_io::@ret0[%c0] : memref<4096x4096xf32> + %mat = vector.transfer_read %0[%c0, %c0], %cst : memref<4096x4096xf32>, vector<32x8xf32> + vector.transfer_write %mat, %1[%c0, %c0] : vector<32x8xf32>, memref<4096x4096xf32> + return +} + +hal.interface @legacy_io attributes {sym_visibility = "private"} { + hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read" + hal.interface.binding @ret0, set=0, binding=1, type="StorageBuffer", access="Write|Discard" +} diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/tile_and_vectorize_conv.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/tile_and_vectorize_conv.mlir new file mode 100644 index 000000000000..165cd52394b8 --- /dev/null +++ b/iree/compiler/Conversion/LinalgToSPIRV/test/tile_and_vectorize_conv.mlir @@ -0,0 +1,176 @@ +// RUN: iree-opt -split-input-file -pass-pipeline="hal.executable(hal.executable.target(iree-spirv-concretize-tile-among-workgroups,iree-codegen-linalg-tile-and-fuse))" -iree-spirv-enable-vectorization -iree-codegen-spirv-experimental-linalg-on-tensors -canonicalize -cse %s | IreeFileCheck %s + +hal.executable @conv_static_shape_f32 attributes {sym_visibility = "private"} { + hal.interface @legacy_io { + hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read" + hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read" + hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard" + } + hal.executable.target @vulkan_spirv, filter="vulkan*" { + hal.executable.entry_point @conv_static_shape_f32 attributes { + interface = @legacy_io, ordinal = 0 : i32, + signature = (!flow.dispatch.input<1x225x225x16xf32>, !flow.dispatch.input<3x3x16x32xf32>, !flow.dispatch.output<1x112x112x32xf32>) -> ()} + module attributes {spv.target_env = #spv.target_env<#spv.vce, ARM:IntegratedGPU, {}>} { + func @conv_static_shape_f32() { + %cst = constant 0.000000e+00 : f32 + %c32 = constant 32 : index + %c112 = constant 112 : index + %c0 = constant 0 : index + %0 = hal.interface.binding.subspan @legacy_io::@arg0[%c0] : memref<1x225x225x16xf32> + %1 = hal.interface.binding.subspan @legacy_io::@arg1[%c0] : memref<3x3x16x32xf32> + %2 = hal.interface.binding.subspan @legacy_io::@ret0[%c0] : memref<1x112x112x32xf32> + %workgroup_size_x = hal.interface.workgroup.size[0] : index + %workgroup_size_y = hal.interface.workgroup.size[1] : index + %workgroup_size_z = hal.interface.workgroup.size[2] : index + %workgroup_id_x = hal.interface.workgroup.id[0] : index + %workgroup_count_x = hal.interface.workgroup.count[0] : index + %workgroup_id_y = hal.interface.workgroup.id[1] : index + %workgroup_count_y = hal.interface.workgroup.count[1] : index + %workgroup_id_z = hal.interface.workgroup.id[2] : index + %workgroup_count_z = hal.interface.workgroup.count[2] : index + %3 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_z, %workgroup_size_z] + %4 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_z, %workgroup_size_z] + scf.for %arg0 = %3 to %c112 step %4 { + %5 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_y, %workgroup_size_y] + %6 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_y, %workgroup_size_y] + scf.for %arg1 = %5 to %c112 step %6 { + %7 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_x, %workgroup_size_x] + %8 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_x, %workgroup_size_x] + scf.for %arg2 = %7 to %c32 step %8 { + %9 = affine.apply affine_map<(d0) -> (d0 * 2)>(%arg0) + %10 = affine.min affine_map<(d0)[s0] -> (s0 * 2 + 1, d0 * -2 + 225)>(%arg0)[%workgroup_size_z] + %11 = affine.apply affine_map<(d0) -> (d0 * 2)>(%arg1) + %12 = affine.min affine_map<(d0)[s0] -> (s0 * 2 + 1, d0 * -2 + 225)>(%arg1)[%workgroup_size_y] + %13 = subview %0[0, %9, %11, 0] [1, %10, %12, 16] [1, 1, 1, 1] : memref<1x225x225x16xf32> to memref<1x?x?x16xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 810000 + s0 + d1 * 3600 + d2 * 16 + d3)>> + %14 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 32)>(%arg2)[%workgroup_size_x] + %15 = subview %1[0, 0, 0, %arg2] [3, 3, 16, %14] [1, 1, 1, 1] : memref<3x3x16x32xf32> to memref<3x3x16x?xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 1536 + s0 + d1 * 512 + d2 * 32 + d3)>> + %16 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 112)>(%arg0)[%workgroup_size_z] + %17 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 112)>(%arg1)[%workgroup_size_y] + %18 = subview %2[0, %arg0, %arg1, %arg2] [1, %16, %17, %14] [1, 1, 1, 1] : memref<1x112x112x32xf32> to memref<1x?x?x?xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 401408 + s0 + d1 * 3584 + d2 * 32 + d3)>> + linalg.fill(%18, %cst) {__internal_linalg_transform__ = "workgroup"} : memref<1x?x?x?xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 401408 + s0 + d1 * 3584 + d2 * 32 + d3)>>, f32 + linalg.conv_2d_input_nhwc_filter_hwcf {__internal_linalg_transform__ = "workgroup", dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%13, %15 : memref<1x?x?x16xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 810000 + s0 + d1 * 3600 + d2 * 16 + d3)>>, memref<3x3x16x?xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 1536 + s0 + d1 * 512 + d2 * 32 + d3)>>) outs(%18 : memref<1x?x?x?xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 401408 + s0 + d1 * 3584 + d2 * 32 + d3)>>) + } + } + } + return + } + hal.interface @legacy_io attributes {sym_visibility = "private"} { + hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read" + hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read" + hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard" + } + } + } +} + +// CHECK-LABEL: func @conv_static_shape_f32() + +// For linalg.fill +// CHECK-COUNT-4: vector.transfer_write + +// For linalg.conv_2d_input_nhwc_filter_hwcf +// CHECK-COUNT-4: vector.transfer_read + +// check tiling loop along filter height/width and input channel +// CHECK: scf.for %{{.*}} = %c0 to %c3 step %c1 +// CHECK-SAME: -> (vector<1x4xf32>, vector<1x4xf32>, vector<1x4xf32>, vector<1x4xf32>) +// CHECK: scf.for %{{.*}} = %c0 to %c3 step %c1 +// CHECK-SAME: -> (vector<1x4xf32>, vector<1x4xf32>, vector<1x4xf32>, vector<1x4xf32>) +// CHECK: scf.for %{{.*}} = %c0 to %c16 step %c4 +// CHECK-SAME: -> (vector<1x4xf32>, vector<1x4xf32>, vector<1x4xf32>, vector<1x4xf32>) + +// CHECK-COUNT-16: vector.contract + +// CHECK-COUNT-3: scf.yield + +// For linalg.conv_2d_input_nhwc_filter_hwcf +// CHECK-COUNT-4: vector.transfer_write + +// ----- + +hal.executable @depthwise_conv_static_shape_f32 attributes {sym_visibility = "private"} { + hal.interface @legacy_io { + hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read" + hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read" + hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard" + } + hal.executable.target @vulkan_spirv, filter="vulkan*" { + hal.executable.entry_point @depthwise_conv_static_shape_f32 attributes { + interface = @legacy_io, ordinal = 0 : i32, + signature = (!flow.dispatch.input<1x225x225x16xf32>, !flow.dispatch.input<3x3x16x32xf32>, !flow.dispatch.output<1x112x112x32xf32>) -> ()} + module attributes {spv.target_env = #spv.target_env<#spv.vce, ARM:IntegratedGPU, {}>} { + func @depthwise_conv_static_shape_f32() { + %cst = constant 0.000000e+00 : f32 + %c96 = constant 96 : index + %c56 = constant 56 : index + %c0 = constant 0 : index + %0 = hal.interface.binding.subspan @legacy_io::@arg0[%c0] : memref<1x113x113x96xf32> + %1 = hal.interface.binding.subspan @legacy_io::@arg1[%c0] : memref<3x3x1x96xf32> + %2 = hal.interface.binding.subspan @legacy_io::@ret0[%c0] : memref<1x56x56x96xf32> + %3 = linalg.reshape %1 [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>] : memref<3x3x1x96xf32> into memref<864xf32> + %4 = linalg.reshape %3 [affine_map<(d0, d1, d2) -> (d0, d1, d2)>] : memref<864xf32> into memref<3x3x96xf32> + %workgroup_size_x = hal.interface.workgroup.size[0] : index + %workgroup_size_y = hal.interface.workgroup.size[1] : index + %workgroup_size_z = hal.interface.workgroup.size[2] : index + %workgroup_id_x = hal.interface.workgroup.id[0] : index + %workgroup_count_x = hal.interface.workgroup.count[0] : index + %workgroup_id_y = hal.interface.workgroup.id[1] : index + %workgroup_count_y = hal.interface.workgroup.count[1] : index + %workgroup_id_z = hal.interface.workgroup.id[2] : index + %workgroup_count_z = hal.interface.workgroup.count[2] : index + %5 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_z, %workgroup_size_z] + %6 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_z, %workgroup_size_z] + scf.for %arg0 = %5 to %c56 step %6 { + %7 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_y, %workgroup_size_y] + %8 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_y, %workgroup_size_y] + scf.for %arg1 = %7 to %c56 step %8 { + %9 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_x, %workgroup_size_x] + %10 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_x, %workgroup_size_x] + scf.for %arg2 = %9 to %c96 step %10 { + %11 = affine.apply affine_map<(d0) -> (d0 * 2)>(%arg0) + %12 = affine.min affine_map<(d0)[s0] -> (s0 * 2 + 1, d0 * -2 + 113)>(%arg0)[%workgroup_size_z] + %13 = affine.apply affine_map<(d0) -> (d0 * 2)>(%arg1) + %14 = affine.min affine_map<(d0)[s0] -> (s0 * 2 + 1, d0 * -2 + 113)>(%arg1)[%workgroup_size_y] + %15 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 96)>(%arg2)[%workgroup_size_x] + %16 = subview %0[0, %11, %13, %arg2] [1, %12, %14, %15] [1, 1, 1, 1] : memref<1x113x113x96xf32> to memref<1x?x?x?xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 1225824 + s0 + d1 * 10848 + d2 * 96 + d3)>> + %17 = subview %4[0, 0, %arg2] [3, 3, %15] [1, 1, 1] : memref<3x3x96xf32> to memref<3x3x?xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * 288 + s0 + d1 * 96 + d2)>> + %18 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 56)>(%arg0)[%workgroup_size_z] + %19 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 56)>(%arg1)[%workgroup_size_y] + %20 = subview %2[0, %arg0, %arg1, %arg2] [1, %18, %19, %15] [1, 1, 1, 1] : memref<1x56x56x96xf32> to memref<1x?x?x?xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 301056 + s0 + d1 * 5376 + d2 * 96 + d3)>> + linalg.fill(%20, %cst) {__internal_linalg_transform__ = "workgroup"} : memref<1x?x?x?xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 301056 + s0 + d1 * 5376 + d2 * 96 + d3)>>, f32 + linalg.depthwise_conv_2d_input_nhwc_filter_hwc {__internal_linalg_transform__ = "workgroup", strides = dense<2> : tensor<2xi64>} ins(%16, %17 : memref<1x?x?x?xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 1225824 + s0 + d1 * 10848 + d2 * 96 + d3)>>, memref<3x3x?xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * 288 + s0 + d1 * 96 + d2)>>) outs(%20 : memref<1x?x?x?xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 301056 + s0 + d1 * 5376 + d2 * 96 + d3)>>) + } + } + } + return + } + hal.interface @legacy_io attributes {sym_visibility = "private"} { + hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read" + hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read" + hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard" + } + } + } +} + +// CHECK-LABEL: func @depthwise_conv_static_shape_f32() + +// For linalg.fill +// CHECK: vector.transfer_write + +// For linalg.depthwise_conv_2d_input_nhwc_filter_hwc +// CHECK: vector.transfer_read + +// check tiling loop along filter height/width and input channel +// CHECK: scf.for %{{.+}} = %c0 to %c3 step %c1 +// CHECK-SAME: -> (vector<4xf32>) +// CHECK: scf.for %{{.+}} = %c0 to %c3 step %c1 +// CHECK-SAME: -> (vector<4xf32>) + + +// CHECK: vector.fma + +// CHECK-COUNT-2: scf.yield + +// For linalg.depthwise_conv_2d_input_nhwc_filter_hwc +// CHECK: vector.transfer_write diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/tile_and_vectorize_matmul.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/tile_and_vectorize_matmul.mlir new file mode 100644 index 000000000000..74d43514e568 --- /dev/null +++ b/iree/compiler/Conversion/LinalgToSPIRV/test/tile_and_vectorize_matmul.mlir @@ -0,0 +1,61 @@ +// RUN: iree-opt -split-input-file -pass-pipeline="hal.executable(hal.executable.target(iree-spirv-concretize-tile-among-workgroups,iree-codegen-linalg-tile-and-fuse))" -iree-spirv-enable-vectorization -iree-codegen-spirv-experimental-linalg-on-tensors -canonicalize -cse %s | IreeFileCheck %s + +hal.executable @matmul_static_shape_f16 attributes {sym_visibility = "private"} { + hal.interface @legacy_io attributes {sym_visibility = "private"} { + hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read" + hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read" + hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard" + } + hal.executable.target @vulkan_spirv, filter="vulkan*" { + hal.executable.entry_point @matmul_static_shape_f16 attributes { + interface = @legacy_io, ordinal = 0 : i32, + signature = (!flow.dispatch.input<1x225x225x16xf32>, !flow.dispatch.input<3x3x16x32xf32>, !flow.dispatch.output<1x112x112x32xf32>) -> ()} + module attributes {spv.target_env = #spv.target_env<#spv.vce, ARM:IntegratedGPU, {}>} { + func @matmul_static_shape_f16() { + %cst = constant 0.000000e+00 : f16 + %c0 = constant 0 : index + %c4096 = constant 4096 : index + %0 = hal.interface.binding.subspan @legacy_io::@arg0[%c0] : memref<4096x4096xf16> + %1 = hal.interface.binding.subspan @legacy_io::@arg1[%c0] : memref<4096x4096xf16> + %2 = hal.interface.binding.subspan @legacy_io::@ret0[%c0] : memref<4096x4096xf16> + %workgroup_size_x = hal.interface.workgroup.size[0] : index + %workgroup_size_y = hal.interface.workgroup.size[1] : index + %workgroup_id_x = hal.interface.workgroup.id[0] : index + %workgroup_count_x = hal.interface.workgroup.count[0] : index + %workgroup_id_y = hal.interface.workgroup.id[1] : index + %workgroup_count_y = hal.interface.workgroup.count[1] : index + %3 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_y, %workgroup_size_y] + %4 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_y, %workgroup_size_y] + scf.for %arg0 = %3 to %c4096 step %4 { + %5 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_x, %workgroup_size_x] + %6 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_x, %workgroup_size_x] + scf.for %arg1 = %5 to %c4096 step %6 { + %7 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 4096)>(%arg0)[%workgroup_size_y] + %8 = subview %0[%arg0, 0] [%7, 4096] [1, 1] : memref<4096x4096xf16> to memref (d0 * 4096 + s0 + d1)>> + %9 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 4096)>(%arg1)[%workgroup_size_x] + %10 = subview %2[%arg0, %arg1] [%7, %9] [1, 1] : memref<4096x4096xf16> to memref (d0 * 4096 + s0 + d1)>> + %11 = subview %1[0, %arg1] [4096, %9] [1, 1] : memref<4096x4096xf16> to memref<4096x?xf16, affine_map<(d0, d1)[s0] -> (d0 * 4096 + s0 + d1)>> + linalg.fill(%10, %cst) {__internal_linalg_transform__ = "workgroup"} : memref (d0 * 4096 + s0 + d1)>>, f16 + linalg.matmul {__internal_linalg_transform__ = "workgroup"} ins(%8, %11 : memref (d0 * 4096 + s0 + d1)>>, memref<4096x?xf16, affine_map<(d0, d1)[s0] -> (d0 * 4096 + s0 + d1)>>) outs(%10 : memref (d0 * 4096 + s0 + d1)>>) + } + } + return + } + hal.interface @legacy_io attributes {sym_visibility = "private"} { + hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read" + hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read" + hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard" + } + } + } +} + +// CHECK-LABEL: func @matmul_static_shape_f16 +// CHECK-COUNT-16: vector.transfer_write +// CHECK-COUNT-16: vector.transfer_read +// CHECK: %[[FOR_RES:.+]]:16 = scf.for +// CHECK-COUNT-16: vector.transfer_read +// CHECK-COUNT-64: vector.contract +// CHECK: scf.yield +// CHECK-COUNT-16: vector.transfer_write %[[FOR_RES]] +// CHECK: return diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/workgroup_memory_promotion.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/workgroup_memory_promotion.mlir index 5da02d9d228c..c96ce203934f 100644 --- a/iree/compiler/Conversion/LinalgToSPIRV/test/workgroup_memory_promotion.mlir +++ b/iree/compiler/Conversion/LinalgToSPIRV/test/workgroup_memory_promotion.mlir @@ -1,4 +1,4 @@ -// RUN: iree-opt -split-input-file -pass-pipeline="hal.executable(hal.executable.target(iree-codegen-linalg-tile-and-fuse,canonicalize,cse))" -iree-spirv-use-workgroup-memory %s | IreeFileCheck %s +// RUN: iree-opt -split-input-file -pass-pipeline="hal.executable(hal.executable.target(iree-codegen-spirv-linalg-tile-and-distribute,iree-codegen-linalg-tile-and-fuse,canonicalize,cse))" -iree-spirv-use-workgroup-memory %s | IreeFileCheck %s // TODO(GH-4901): Convert these tests back to use dynamic shapes when linalg on tensors becomes default. hal.executable @matmul_tile attributes {sym_visibility = "private"} { diff --git a/iree/compiler/Conversion/init_conversions.h b/iree/compiler/Conversion/init_conversions.h index 9e0f52459371..cb1c9eb71028 100644 --- a/iree/compiler/Conversion/init_conversions.h +++ b/iree/compiler/Conversion/init_conversions.h @@ -65,6 +65,7 @@ inline void registerLinalgToSPIRVPasses() { // LinalgToSPIRV createConvertToGPUPass(SPIRVCodegenOptions()); createFoldProcessorIDUsesPass(); + createTileAndDistributeAmongWorkgroupsPass(SPIRVCodegenOptions()); createLinalgTileAndFusePass(SPIRVCodegenOptions()); createSplitDispatchFunctionPass(); createVectorToGPUPass(); diff --git a/iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensors.cpp b/iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensors.cpp index 739a3c1617c4..f84b74350f17 100644 --- a/iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensors.cpp +++ b/iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensors.cpp @@ -20,8 +20,10 @@ #include "iree/compiler/Dialect/Shape/IR/ShapeDialect.h" #include "iree/compiler/Dialect/Shape/IR/ShapeOps.h" #include "llvm/ADT/STLExtras.h" +#include "mlir/Dialect/Linalg/IR/LinalgOps.h" #include "mlir/Dialect/Linalg/Transforms/Transforms.h" #include "mlir/Dialect/SCF/SCF.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/IR/Block.h" #include "mlir/IR/BlockAndValueMapping.h" #include "mlir/IR/Builders.h" @@ -178,7 +180,9 @@ static SmallVector convertToWorkload(OpBuilder &b, Location loc, /// linalg.init_tensor operations. static bool isRootOp(Operation *op) { - return isa(op); + return isa(op); } static bool isAlwaysClonedIntoDispatchOp(Operation *op) { @@ -524,6 +528,14 @@ struct TileAndDistributeOnTensorsPattern SmallVector count = llvm::to_vector<4>( llvm::map_range(linalgOp.createLoopRanges(rewriter, loc), [](Range r) { return r.size; })); + // NOTE: Special treatment for convolution, which have more than 3 parallel + // dimensions. We want to ignore the batch dimension and tile along the + // next three. + // TODO(#5048): figure out a better way to avoid this special case. + if (isa(op)) { + count.erase(count.begin()); + } count.resize(getNumTilableLoops(op)); auto workload = convertToWorkload(rewriter, loc, count); @@ -770,10 +782,9 @@ void DispatchLinalgOnTensorsPass::runOnOperation() { static linalg::LinalgLoopDistributionOptions workgroupDistributionOptions = { [](OpBuilder &builder, Location loc, ArrayRef parallelLoopRanges) { auto numParallelDims = parallelLoopRanges.size(); + SmallVector procInfo(numParallelDims); - for (size_t dim = 0; - dim < std::min(numParallelDims, kNumMaxParallelDims); - ++dim) { + for (size_t dim = 0; dim < numParallelDims; ++dim) { procInfo[numParallelDims - dim - 1] = { buildFlowWorkgroupInfoOp(builder, dim), @@ -787,21 +798,34 @@ void DispatchLinalgOnTensorsPass::runOnOperation() { auto tileSizeFn = [&](OpBuilder &builder, Operation *op) -> SmallVector { + auto numParallelDims = getNumOuterParallelLoops(cast(op)); auto numTiledLoops = getNumTilableLoops(cast(op)); - SmallVector useTileSizes(numTiledLoops); + + // Default to zero to skip tiling. + auto zero = builder.create(op->getLoc(), 0); + SmallVector useTileSizes(numParallelDims, zero); + if (!clLinalgOnTensorsTileSizes.empty()) { SmallVector tileSizes(clLinalgOnTensorsTileSizes.begin(), clLinalgOnTensorsTileSizes.end()); - useTileSizes.resize(std::min(tileSizes.size(), numTiledLoops)); + useTileSizes.resize(std::min(tileSizes.size(), numParallelDims)); return llvm::to_vector<4>(llvm::map_range( ArrayRef(tileSizes).take_front( - std::min(tileSizes.size(), numTiledLoops)), + std::min(tileSizes.size(), numParallelDims)), [&](int64_t t) -> Value { return builder.create(op->getLoc(), t); })); } + + // NOTE: Special treatment for convolution, which have more than 3 + // parallel dimensions. We want to ignore the batch dimension and tile + // along the next three. That means setting the first position to zero. + // TODO(#5048): figure out a better way to avoid this special case. + bool isConvOp = isa(op); + for (size_t dim = 0; dim < numTiledLoops; ++dim) { - useTileSizes[numTiledLoops - dim - 1] = + useTileSizes[(isConvOp ? numParallelDims : numTiledLoops) - dim - 1] = buildFlowWorkgroupInfoOp(builder, dim); } return useTileSizes; @@ -858,6 +882,19 @@ void DispatchLinalgOnTensorsPass::runOnOperation() { return signalPassFailure(); } + // Run necessary canonicalization patterns before destructive updates. + { + OwningRewritePatternList patterns; + // This is needed because tiling and distribution may create + // subtensor_insert ops whose source operands come from tensor.cast ops. + // Those tensor.cast ops cast tensors into a more dynamic shape, in order + // to guarantee type match during transformation. Later in destructive + // update subtensor_insert ops will be turned into flow dispatch output + // store ops. + SubTensorInsertOp::getCanonicalizationPatterns(patterns, context); + (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); + } + // Rewrite destructive updates and ensure no remaining store remains to the // full output. if (funcOp diff --git a/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors.mlir b/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors.mlir index 70fe28f58566..ac5225c94c90 100644 --- a/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors.mlir +++ b/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors.mlir @@ -194,3 +194,33 @@ func @tensor5(%A: tensor, %B: tensor, %C: tensor) // CHECK: return %[[D]], %[[origCC]] return %D, %CC: tensor, tensor } + +func @conv2d(%input: tensor<1x225x225x16xf32>, %filter: tensor<3x3x16x32xf32>) -> tensor<1x112x112x32xf32> { + %0 = linalg.init_tensor [1, 112, 112, 32] : tensor<1x112x112x32xf32> + %cst = constant 0.000000e+00 : f32 + %1 = linalg.fill(%0, %cst) : tensor<1x112x112x32xf32>, f32 -> tensor<1x112x112x32xf32> + %2 = linalg.conv_2d_input_nhwc_filter_hwcf + {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} + ins(%input, %filter : tensor<1x225x225x16xf32>, tensor<3x3x16x32xf32>) + outs(%1 : tensor<1x112x112x32xf32>) + -> tensor<1x112x112x32xf32> + return %2 : tensor<1x112x112x32xf32> +} + +// CHECK-LABEL: func @conv2d +// CHECK: scf.for +// CHECK: scf.for +// CHECK: linalg.conv_2d_input_nhwc_filter_hwcf + +func @depthwise_conv2d(%input: tensor<1x113x113x96xf32>, %filter: tensor<3x3x96xf32>) -> tensor<1x56x56x96xf32> { + %cst = constant 0.000000e+00 : f32 + %1 = linalg.init_tensor [1, 56, 56, 96] : tensor<1x56x56x96xf32> + %2 = linalg.fill(%1, %cst) : tensor<1x56x56x96xf32>, f32 -> tensor<1x56x56x96xf32> + %4 = linalg.depthwise_conv_2d_input_nhwc_filter_hwc {strides = dense<2> : tensor<2xi64>} ins(%input, %filter : tensor<1x113x113x96xf32>, tensor<3x3x96xf32>) outs(%2 : tensor<1x56x56x96xf32>) -> tensor<1x56x56x96xf32> + return %4 : tensor<1x56x56x96xf32> +} + +// CHECK-LABEL: func @depthwise_conv2d +// CHECK: scf.for +// CHECK: scf.for +// CHECK: linalg.depthwise_conv_2d_input_nhwc_filter_hwc diff --git a/iree/test/e2e/vulkan_specific/BUILD b/iree/test/e2e/vulkan_specific/BUILD index 689361b25a31..876cc74c4727 100644 --- a/iree/test/e2e/vulkan_specific/BUILD +++ b/iree/test/e2e/vulkan_specific/BUILD @@ -74,6 +74,8 @@ iree_check_single_backend_test_suite( "vectorized_conv.mlir", ], compiler_flags = [ + "-iree-flow-dispatch-linalg-on-tensors", + "-iree-codegen-spirv-experimental-linalg-on-tensors", "-iree-spirv-enable-vectorization", "-iree-vulkan-target-triple=valhall-g77-unknown-android10", ], diff --git a/iree/test/e2e/vulkan_specific/CMakeLists.txt b/iree/test/e2e/vulkan_specific/CMakeLists.txt index f476cdaaf052..c5ec28c9189e 100644 --- a/iree/test/e2e/vulkan_specific/CMakeLists.txt +++ b/iree/test/e2e/vulkan_specific/CMakeLists.txt @@ -66,6 +66,8 @@ iree_check_single_backend_test_suite( DRIVER "vulkan" COMPILER_FLAGS + "-iree-flow-dispatch-linalg-on-tensors" + "-iree-codegen-spirv-experimental-linalg-on-tensors" "-iree-spirv-enable-vectorization" "-iree-vulkan-target-triple=valhall-g77-unknown-android10" ) From 313c09a4ea283a0902dc5ca878915344c964b058 Mon Sep 17 00:00:00 2001 From: Geoffrey Martin-Noble Date: Thu, 11 Mar 2021 10:58:35 -0800 Subject: [PATCH 11/17] Fix file lists in enforce_glob (#5067) These come from midair collision between https://github.com/google/iree/pull/5063 and https://github.com/google/iree/pull/4881 --- iree/compiler/Dialect/Flow/IR/BUILD | 1 + iree/compiler/Dialect/Flow/Transforms/test/BUILD | 2 -- iree/compiler/Dialect/Flow/Transforms/test/CMakeLists.txt | 2 -- iree/compiler/Dialect/Shape/IR/BUILD | 1 + 4 files changed, 2 insertions(+), 4 deletions(-) diff --git a/iree/compiler/Dialect/Flow/IR/BUILD b/iree/compiler/Dialect/Flow/IR/BUILD index 940064be7e4d..9668dd208784 100644 --- a/iree/compiler/Dialect/Flow/IR/BUILD +++ b/iree/compiler/Dialect/Flow/IR/BUILD @@ -27,6 +27,7 @@ filegroup( srcs = enforce_glob( [ "FlowBase.td", + "FlowInterfaces.td", "FlowOps.td", ], include = ["*.td"], diff --git a/iree/compiler/Dialect/Flow/Transforms/test/BUILD b/iree/compiler/Dialect/Flow/Transforms/test/BUILD index 048f9f65da28..3903093842b6 100644 --- a/iree/compiler/Dialect/Flow/Transforms/test/BUILD +++ b/iree/compiler/Dialect/Flow/Transforms/test/BUILD @@ -48,9 +48,7 @@ iree_lit_test_suite( "materialize_exported_reflection.mlir", "merge_exported_reflection.mlir", "outline_dispatch_regions2.mlir", - "outline_dispatch_regions_ranked_dynamic.mlir", "outline_large_constants.mlir", - "rematerialize_dispatch_constants.mlir", "strip_and_splat_constant_variables.mlir", "transformation.mlir", ], diff --git a/iree/compiler/Dialect/Flow/Transforms/test/CMakeLists.txt b/iree/compiler/Dialect/Flow/Transforms/test/CMakeLists.txt index f1cf8b711124..1b9ea4dca2f7 100644 --- a/iree/compiler/Dialect/Flow/Transforms/test/CMakeLists.txt +++ b/iree/compiler/Dialect/Flow/Transforms/test/CMakeLists.txt @@ -37,9 +37,7 @@ iree_lit_test_suite( "materialize_exported_reflection.mlir" "merge_exported_reflection.mlir" "outline_dispatch_regions2.mlir" - "outline_dispatch_regions_ranked_dynamic.mlir" "outline_large_constants.mlir" - "rematerialize_dispatch_constants.mlir" "strip_and_splat_constant_variables.mlir" "transformation.mlir" DATA diff --git a/iree/compiler/Dialect/Shape/IR/BUILD b/iree/compiler/Dialect/Shape/IR/BUILD index 2a0ff9592270..59d237540531 100644 --- a/iree/compiler/Dialect/Shape/IR/BUILD +++ b/iree/compiler/Dialect/Shape/IR/BUILD @@ -29,6 +29,7 @@ filegroup( srcs = enforce_glob( [ "ShapeBase.td", + "ShapeInterfaces.td", "ShapeOps.td", ], include = ["*.td"], From f8fb63238966358e59c8c506580e6f8b560a02e3 Mon Sep 17 00:00:00 2001 From: Geoffrey Martin-Noble Date: Thu, 11 Mar 2021 11:07:51 -0800 Subject: [PATCH 12/17] Fix enforce_glob midair collision (#5069) Midair collision between https://github.com/google/iree/pull/4999 and https://github.com/google/iree/pull/5063 --- iree/compiler/Conversion/LinalgToSPIRV/test/BUILD | 3 +++ iree/compiler/Conversion/LinalgToSPIRV/test/CMakeLists.txt | 3 +++ 2 files changed, 6 insertions(+) diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/BUILD b/iree/compiler/Conversion/LinalgToSPIRV/test/BUILD index 8ed1c59a3487..761d17431f4a 100644 --- a/iree/compiler/Conversion/LinalgToSPIRV/test/BUILD +++ b/iree/compiler/Conversion/LinalgToSPIRV/test/BUILD @@ -28,6 +28,7 @@ iree_lit_test_suite( srcs = enforce_glob( [ "batch_matmul_vectorization.mlir", + "concretize_tile_among_workgroups.mlir", "convert_to_gpu.mlir", "convert_to_spirv.mlir", "dead_alloc.mlir", @@ -43,6 +44,8 @@ iree_lit_test_suite( "pipeline_test.mlir", "pipeline_test_cooperative_mat.mlir", "split_dispatch_function.mlir", + "tile_and_vectorize_conv.mlir", + "tile_and_vectorize_matmul.mlir", "vector_to_gpu.mlir", "workgroup_memory_promotion.mlir", ], diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/CMakeLists.txt b/iree/compiler/Conversion/LinalgToSPIRV/test/CMakeLists.txt index 30df4bd36c9e..9c7c5001e05f 100644 --- a/iree/compiler/Conversion/LinalgToSPIRV/test/CMakeLists.txt +++ b/iree/compiler/Conversion/LinalgToSPIRV/test/CMakeLists.txt @@ -15,6 +15,7 @@ iree_lit_test_suite( lit SRCS "batch_matmul_vectorization.mlir" + "concretize_tile_among_workgroups.mlir" "convert_to_gpu.mlir" "convert_to_spirv.mlir" "dead_alloc.mlir" @@ -30,6 +31,8 @@ iree_lit_test_suite( "pipeline_test.mlir" "pipeline_test_cooperative_mat.mlir" "split_dispatch_function.mlir" + "tile_and_vectorize_conv.mlir" + "tile_and_vectorize_matmul.mlir" "vector_to_gpu.mlir" "workgroup_memory_promotion.mlir" DATA From c05dd63f9f90c24b54121a2875aafe5f02446adf Mon Sep 17 00:00:00 2001 From: Thomas Date: Thu, 11 Mar 2021 11:37:48 -0800 Subject: [PATCH 13/17] Update dockers to include CUDA lib (#5055) --- build_tools/docker/cmake-bazel-frontends-nvidia/Dockerfile | 4 +++- build_tools/docker/cmake-python-nvidia/Dockerfile | 4 +++- build_tools/docker/prod_digests.txt | 4 ++-- .../gcp_ubuntu/cmake-bazel/linux/x86-turing/build_kokoro.sh | 2 +- .../kokoro/gcp_ubuntu/cmake/linux/x86-turing/build_kokoro.sh | 2 +- 5 files changed, 10 insertions(+), 6 deletions(-) diff --git a/build_tools/docker/cmake-bazel-frontends-nvidia/Dockerfile b/build_tools/docker/cmake-bazel-frontends-nvidia/Dockerfile index 8ef012a21356..7251f8ba64d4 100644 --- a/build_tools/docker/cmake-bazel-frontends-nvidia/Dockerfile +++ b/build_tools/docker/cmake-bazel-frontends-nvidia/Dockerfile @@ -26,4 +26,6 @@ FROM gcr.io/iree-oss/cmake-bazel-frontends-vulkan@sha256:3656d9c3a08770f8371ad3a6f777b979c954656dae1cdbdfadfccdd7ab713f87 AS final RUN apt-get update \ - && DEBIAN_FRONTEND=noninteractive apt-get install -y libnvidia-gl-460=460.32.03-0ubuntu0.18.04.1 + && DEBIAN_FRONTEND=noninteractive apt-get install -y \ + libnvidia-gl-460=460.39-0ubuntu0.18.04.1 \ + libnvidia-compute-460=460.39-0ubuntu0.18.04.1 diff --git a/build_tools/docker/cmake-python-nvidia/Dockerfile b/build_tools/docker/cmake-python-nvidia/Dockerfile index ef2803bca003..488e0de86859 100644 --- a/build_tools/docker/cmake-python-nvidia/Dockerfile +++ b/build_tools/docker/cmake-python-nvidia/Dockerfile @@ -28,4 +28,6 @@ FROM gcr.io/iree-oss/cmake-python-vulkan@sha256:6722f69c6300749f6bd4b141fc653244990381a6b0111f9c361061adcd65c07c AS final RUN apt-get update \ - && DEBIAN_FRONTEND=noninteractive apt-get install -y libnvidia-gl-460=460.32.03-0ubuntu0.18.04.1 + && DEBIAN_FRONTEND=noninteractive apt-get install -y \ + libnvidia-gl-460=460.39-0ubuntu0.18.04.1 \ + libnvidia-compute-460=460.39-0ubuntu0.18.04.1 diff --git a/build_tools/docker/prod_digests.txt b/build_tools/docker/prod_digests.txt index d119c1817b32..6267a6feafd9 100644 --- a/build_tools/docker/prod_digests.txt +++ b/build_tools/docker/prod_digests.txt @@ -9,8 +9,8 @@ gcr.io/iree-oss/vulkan@sha256:5812ee64806a7f3df0739ccf0930c27cabce346901488eceb1 gcr.io/iree-oss/rbe-toolchain@sha256:d69c260b98a97ad430d34c4591fb2399e00888750f5d47ede00c1e6f3e774e5a gcr.io/iree-oss/cmake-python-vulkan@sha256:6722f69c6300749f6bd4b141fc653244990381a6b0111f9c361061adcd65c07c gcr.io/iree-oss/cmake-python-swiftshader@sha256:0be2b0c735a038365e7cad31f6b440805dd4e231e166a114ef22914a5469cbc8 -gcr.io/iree-oss/cmake-python-nvidia@sha256:2d823e6fc528d0da5296ae35c0aad8008306381e3ac923679c4052adedd029d7 +gcr.io/iree-oss/cmake-python-nvidia@sha256:0c931cac303791af85c5a717b418997cac9f3319717f59f7a70ac777edfa7b33 gcr.io/iree-oss/cmake-bazel-frontends@sha256:2a9c65cea8b061696c8217b78e9718ceb804b9245bb3a661ff8a56f832aeeb0a gcr.io/iree-oss/cmake-bazel-frontends-vulkan@sha256:3656d9c3a08770f8371ad3a6f777b979c954656dae1cdbdfadfccdd7ab713f87 -gcr.io/iree-oss/cmake-bazel-frontends-nvidia@sha256:e71d77c2e8b886e99ffce2f50c73c5ae5e316bde23b54f0e6090bb04ce2d1bdf +gcr.io/iree-oss/cmake-bazel-frontends-nvidia@sha256:bc1502af0679d301feb27477c2476006282b5f956cecb650b2716c7a2876e722 gcr.io/iree-oss/cmake-bazel-frontends-swiftshader@sha256:1466d3658f872a0675b3ac605c2e577def880c7334738d677d5ef2f1c1291646 diff --git a/build_tools/kokoro/gcp_ubuntu/cmake-bazel/linux/x86-turing/build_kokoro.sh b/build_tools/kokoro/gcp_ubuntu/cmake-bazel/linux/x86-turing/build_kokoro.sh index bd98aab86c00..440ea340446b 100755 --- a/build_tools/kokoro/gcp_ubuntu/cmake-bazel/linux/x86-turing/build_kokoro.sh +++ b/build_tools/kokoro/gcp_ubuntu/cmake-bazel/linux/x86-turing/build_kokoro.sh @@ -38,7 +38,7 @@ docker_setup docker run "${DOCKER_RUN_ARGS[@]?}" \ --gpus all \ - gcr.io/iree-oss/cmake-bazel-frontends-nvidia@sha256:e71d77c2e8b886e99ffce2f50c73c5ae5e316bde23b54f0e6090bb04ce2d1bdf \ + gcr.io/iree-oss/cmake-bazel-frontends-nvidia@sha256:bc1502af0679d301feb27477c2476006282b5f956cecb650b2716c7a2876e722 \ build_tools/kokoro/gcp_ubuntu/cmake-bazel/linux/x86-turing/build.sh # Kokoro will rsync this entire directory back to the executor orchestrating the diff --git a/build_tools/kokoro/gcp_ubuntu/cmake/linux/x86-turing/build_kokoro.sh b/build_tools/kokoro/gcp_ubuntu/cmake/linux/x86-turing/build_kokoro.sh index 0b20fd6d1140..46513ed81c20 100755 --- a/build_tools/kokoro/gcp_ubuntu/cmake/linux/x86-turing/build_kokoro.sh +++ b/build_tools/kokoro/gcp_ubuntu/cmake/linux/x86-turing/build_kokoro.sh @@ -38,7 +38,7 @@ docker_setup docker run "${DOCKER_RUN_ARGS[@]?}" \ --gpus all \ - gcr.io/iree-oss/cmake-python-nvidia@sha256:2d823e6fc528d0da5296ae35c0aad8008306381e3ac923679c4052adedd029d7 \ + gcr.io/iree-oss/cmake-python-nvidia@sha256:0c931cac303791af85c5a717b418997cac9f3319717f59f7a70ac777edfa7b33 \ build_tools/kokoro/gcp_ubuntu/cmake/linux/x86-turing/build.sh # Kokoro will rsync this entire directory back to the executor orchestrating the From 9467a90339b2f466d5a2666aca8e31b3f84906f5 Mon Sep 17 00:00:00 2001 From: Geoffrey Martin-Noble Date: Thu, 11 Mar 2021 11:57:11 -0800 Subject: [PATCH 14/17] Make error message in enforce_glob more copy-pasteable (#5071) --- build_tools/bazel/enforce_glob.bzl | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/build_tools/bazel/enforce_glob.bzl b/build_tools/bazel/enforce_glob.bzl index 58cb4541087f..1ee7c62284ba 100644 --- a/build_tools/bazel/enforce_glob.bzl +++ b/build_tools/bazel/enforce_glob.bzl @@ -42,9 +42,18 @@ def enforce_glob(files, **kwargs): result_dict = {k: None for k in files} missing = [k for k in glob_result if k not in files] extra = [k for k in files if k not in glob_result] + expected_formatted = "\n".join(['"{}",'.format(file) for file in glob_result]) fail(("Error in enforce_glob." + "\nExpected {}." + "\nGot {}." + "\nMissing {}." + - "\nExtra {}").format(glob_result, files, missing, extra)) + "\nExtra {}" + + "\nPaste this into the first enforce_glob argument:" + + "\n{}").format( + glob_result, + files, + missing, + extra, + expected_formatted, + )) return files From 54749cedb82724b7ecadecd2101dbef836ff2093 Mon Sep 17 00:00:00 2001 From: Ben Vanik Date: Tue, 2 Mar 2021 16:17:08 -0800 Subject: [PATCH 15/17] Adding TiedOpInterface and wiring it through the Flow dialect. This allows for results of operations to be tied back to their operands in storage but not in time. This allows for in-place operations to be defined on tensors that carry enough metadata to be able to correctly form streams, materialize HAL interfaces, and allocate buffers. Example: ```mlir %t = flow.dispatch @foo[...](%input) : (tensor<4xf32>) -> %input ``` This syntax also combines with the shape-carrying op interface to make it possible to also indicate that an input and a result share type and shape information: ```mlir %t = flow.dispatch @foo[...](%input) : (tensor{%dim}) -> %input ``` which is effectively: ```mlir %t = flow.dispatch @foo[...](%input) : (tensor{%dim}) -> tensor{%dim} ``` but with the extra bit that result 0 is tied to operand 0. Here the result %t of the dispatch aliases the storage for %input, making %input a read-write/mutable binding in the resulting HAL executable. %t is a distinct SSA value from %input, though, and represents the value of the storage backing %input after the dispatch has completed. By keeping the SSA use-def chains correct with respect to time they are still meaningful for analysi2As and nothing at this level (and the beginning of the HAL transformations) needs to perform alias analysis, while still giving us all of the information required to induce aliasing during later allocation passes. --- .../Conversion/Common/LinalgBufferizePass.cpp | 16 +- .../Common/test/linalg_bufferize.mlir | 166 ++++----- .../LinalgToLLVM/test/linalg_vectorize.mlir | 24 +- .../materialize_launch_configuration.mlir | 8 +- .../test/matmul_vectorization.mlir | 8 +- .../test/tile_and_distribute.mlir | 8 +- .../LinalgToNVVM/test/pipeline_test.mlir | 16 +- .../test/batch_matmul_vectorization.mlir | 8 +- .../concretize_tile_among_workgroups.mlir | 4 +- .../LinalgToSPIRV/test/convert_to_gpu.mlir | 28 +- .../test/linalg_tile_and_fuse.mlir | 26 +- .../materialize_launch_configuration.mlir | 4 +- .../materialize_launch_configuration2.mlir | 4 +- .../test/matmul_fused_vectorization.mlir | 4 +- .../test/matmul_vectorization.mlir | 6 +- .../LinalgToSPIRV/test/pipeline_test.mlir | 20 +- .../test/pipeline_test_cooperative_mat.mlir | 4 +- .../test/split_dispatch_function.mlir | 52 +-- .../test/tile_and_vectorize_conv.mlir | 4 +- .../test/tile_and_vectorize_matmul.mlir | 2 +- .../test/workgroup_memory_promotion.mlir | 8 +- iree/compiler/Dialect/Flow/IR/FlowBase.td | 55 ++- iree/compiler/Dialect/Flow/IR/FlowDialect.cpp | 12 +- .../Dialect/Flow/IR/FlowOpFolders.cpp | 142 +++++++- iree/compiler/Dialect/Flow/IR/FlowOps.cpp | 341 +++++++++++++++--- iree/compiler/Dialect/Flow/IR/FlowOps.td | 66 +++- iree/compiler/Dialect/Flow/IR/FlowTypes.cpp | 105 +++--- iree/compiler/Dialect/Flow/IR/FlowTypes.h | 112 ++---- .../Dialect/Flow/IR/test/dispatch_ops.mlir | 32 +- .../Flow/IR/test/dispatch_workgroups.mlir | 56 ++- .../IR/test/dispatch_workgroups_folding.mlir | 26 +- .../Dialect/Flow/IR/test/stream_folding.mlir | 63 +++- .../Dialect/Flow/IR/test/tensor_folding.mlir | 9 - iree/compiler/Dialect/Flow/IR/test/types.mlir | 40 +- .../Flow/Transforms/CreateBenchmarkFuncs.cpp | 3 +- .../Transforms/DestructiveUpdateUtils.cpp | 20 +- .../Flow/Transforms/DestructiveUpdateUtils.h | 25 +- .../Transforms/DispatchLinalgOnTensors.cpp | 21 +- .../Dialect/Flow/Transforms/FormStreams.cpp | 3 +- .../Transforms/OutlineDispatchRegions.cpp | 2 +- .../Transforms/OutlineDispatchRegions2.cpp | 14 +- .../test/create_benchmark_funcs.mlir | 8 +- .../test/deduplicate_executables.mlir | 30 +- .../test/dispatch_linalg_on_tensors.mlir | 34 +- .../dispatch_linalg_on_tensors_dynamic.mlir | 88 ++--- .../Flow/Transforms/test/form_streams.mlir | 75 ++-- .../test/hoist_unstreamable_ops.mlir | 12 +- .../test/inject_dispatch_tracing.mlir | 6 +- .../test/outline_dispatch_regions2.mlir | 72 ++-- .../Flow/Transforms/test/transformation.mlir | 35 +- .../Conversion/FlowToHAL/ConvertStreamOps.cpp | 116 +++++- .../Conversion/FlowToHAL/test/stream_ops.mlir | 9 +- .../HAL/Target/CUDA/test/smoketest.mlir | 8 +- .../HAL/Transforms/MaterializeInterfaces.cpp | 47 ++- .../Dialect/HAL/Transforms/Passes.cpp | 3 + .../test/materialize_interfaces.mlir | 54 +-- iree/compiler/Dialect/IREE/IR/BUILD | 18 + iree/compiler/Dialect/IREE/IR/CMakeLists.txt | 12 + iree/compiler/Dialect/IREE/IR/IREEBase.td | 5 + .../Dialect/IREE/IR/IREEInterfaces.td | 166 +++++++++ iree/compiler/Dialect/IREE/IR/IREETypes.cpp | 119 ++++++ iree/compiler/Dialect/IREE/IR/IREETypes.h | 23 ++ iree/test/e2e/models/unidirectional_lstm.mlir | 6 +- 63 files changed, 1700 insertions(+), 813 deletions(-) create mode 100644 iree/compiler/Dialect/IREE/IR/IREEInterfaces.td diff --git a/iree/compiler/Conversion/Common/LinalgBufferizePass.cpp b/iree/compiler/Conversion/Common/LinalgBufferizePass.cpp index 0a61ea2c5d49..3f1f3b0e453b 100644 --- a/iree/compiler/Conversion/Common/LinalgBufferizePass.cpp +++ b/iree/compiler/Conversion/Common/LinalgBufferizePass.cpp @@ -407,7 +407,7 @@ static SmallVector extractFromI64ArrayAttr(Attribute attr) { } LogicalResult convertInterfaceLoadTensorOp( - OpBuilder &b, IREE::Flow::DispatchInputLoadOp loadOp, + OpBuilder &b, IREE::Flow::DispatchTensorLoadOp loadOp, BlockAndValueMapping &bvm) { OpBuilder::InsertionGuard g(b); b.setInsertionPoint(loadOp); @@ -449,7 +449,7 @@ static Operation *getInsertionPointForReplacementStoreOp( /// LinalgOp, create the subview operation that can be used by the op itself to /// store the result into directly. This avoids an extra allocation + copies. LogicalResult preProcessInterfaceStoreTensorOp( - OpBuilder &b, IREE::Flow::DispatchOutputStoreOp storeOp, + OpBuilder &b, IREE::Flow::DispatchTensorStoreOp storeOp, BlockAndValueMapping &bvm) { // Find the insertion point for the subview. SmallVector operandsOfSubviewOp; @@ -491,7 +491,7 @@ LogicalResult preProcessLinalgOps(OpBuilder &b, linalg::LinalgOp op, } LogicalResult convertInterfaceStoreTensorOp( - OpBuilder &b, IREE::Flow::DispatchOutputStoreOp storeOp, + OpBuilder &b, IREE::Flow::DispatchTensorStoreOp storeOp, BlockAndValueMapping &bvm) { OpBuilder::InsertionGuard g(b); b.setInsertionPoint(storeOp); @@ -570,7 +570,7 @@ void LinalgBufferizePass::runOnFunction() { transferShapeOpsToMemref(b, op.getResult(), baseBuffer.getResult(), bvm); }); if (funcOp - .walk([&](IREE::Flow::DispatchOutputStoreOp op) -> WalkResult { + .walk([&](IREE::Flow::DispatchTensorStoreOp op) -> WalkResult { return preProcessInterfaceStoreTensorOp(b, op, bvm); }) .wasInterrupted()) { @@ -596,12 +596,12 @@ void LinalgBufferizePass::runOnFunction() { auto conversionDispatch = [&](Operation *op) -> WalkResult { return TypeSwitch(op) - .Case( - [&](IREE::Flow::DispatchInputLoadOp loadOp) { + .Case( + [&](IREE::Flow::DispatchTensorLoadOp loadOp) { return convertInterfaceLoadTensorOp(b, loadOp, bvm); }) - .Case( - [&](IREE::Flow::DispatchOutputStoreOp storeOp) { + .Case( + [&](IREE::Flow::DispatchTensorStoreOp storeOp) { return convertInterfaceStoreTensorOp(b, storeOp, bvm); }) .Case([&](linalg::LinalgOp linalgOp) { diff --git a/iree/compiler/Conversion/Common/test/linalg_bufferize.mlir b/iree/compiler/Conversion/Common/test/linalg_bufferize.mlir index 7257acbd8f2b..e5f547306003 100644 --- a/iree/compiler/Conversion/Common/test/linalg_bufferize.mlir +++ b/iree/compiler/Conversion/Common/test/linalg_bufferize.mlir @@ -6,20 +6,20 @@ func @tile_from_tensor_load() { %c4 = constant 4 : index %c1 = constant 1 : index %c3 = constant 3 : index - %0 = hal.interface.binding.subspan @legacy_io::@TENSOR_LHS[%c0] : !flow.dispatch.input - %1 = hal.interface.binding.subspan @legacy_io::@TENSOR_RHS[%c0] : !flow.dispatch.input - %2 = hal.interface.binding.subspan @legacy_io::@TENSOR_INIT[%c0] : !flow.dispatch.input - %3 = hal.interface.binding.subspan @legacy_io::@ret0[%c0] : !flow.dispatch.output + %0 = hal.interface.binding.subspan @legacy_io::@TENSOR_LHS[%c0] : !flow.dispatch.tensor + %1 = hal.interface.binding.subspan @legacy_io::@TENSOR_RHS[%c0] : !flow.dispatch.tensor + %2 = hal.interface.binding.subspan @legacy_io::@TENSOR_INIT[%c0] : !flow.dispatch.tensor + %3 = hal.interface.binding.subspan @legacy_io::@ret0[%c0] : !flow.dispatch.tensor %4 = hal.interface.workgroup.id[0] : index %5 = hal.interface.workgroup.id[1] : index scf.for %arg0 = %5 to %c2 step %c2 { scf.for %arg1 = %4 to %c4 step %c4 { - %6 = flow.dispatch.input.load %0, offsets = [%arg0, %c0], sizes = [%c1, %c3], strides = [%c1, %c1] : !flow.dispatch.input -> tensor<1x3xf32> - %7 = flow.dispatch.input.load %1, offsets = [%c0, %arg1], sizes = [%c3, %c1], strides = [%c1, %c1] : !flow.dispatch.input -> tensor<3x1xf32> - %8 = flow.dispatch.input.load %2, offsets = [%arg0, %arg1], sizes = [%c1, %c1], strides = [%c1, %c1] : !flow.dispatch.input -> tensor<1x1xf32> + %6 = flow.dispatch.tensor.load %0, offsets = [%arg0, %c0], sizes = [%c1, %c3], strides = [%c1, %c1] : !flow.dispatch.tensor -> tensor<1x3xf32> + %7 = flow.dispatch.tensor.load %1, offsets = [%c0, %arg1], sizes = [%c3, %c1], strides = [%c1, %c1] : !flow.dispatch.tensor -> tensor<3x1xf32> + %8 = flow.dispatch.tensor.load %2, offsets = [%arg0, %arg1], sizes = [%c1, %c1], strides = [%c1, %c1] : !flow.dispatch.tensor -> tensor<1x1xf32> %9 = linalg.matmul ins(%6, %7 : tensor<1x3xf32>, tensor<3x1xf32>) outs(%8 : tensor<1x1xf32>) -> tensor<1x1xf32> - flow.dispatch.output.store %9, %3, offsets = [%arg0, %arg1], sizes = [%c1, %c1], strides = [%c1, %c1] : tensor<1x1xf32> -> !flow.dispatch.output + flow.dispatch.tensor.store %9, %3, offsets = [%arg0, %arg1], sizes = [%c1, %c1], strides = [%c1, %c1] : tensor<1x1xf32> -> !flow.dispatch.tensor } } return @@ -56,25 +56,25 @@ func @tile_from_pointwise_lhs() { %c4 = constant 4 : index %c1 = constant 1 : index %c3 = constant 3 : index - %0 = hal.interface.binding.subspan @legacy_io::@TENSOR_LHS[%c0] : !flow.dispatch.input - %1 = hal.interface.binding.subspan @legacy_io::@TENSOR_RHS[%c0] : !flow.dispatch.input - %2 = hal.interface.binding.subspan @legacy_io::@TENSOR_INIT[%c0] : !flow.dispatch.input - %3 = hal.interface.binding.subspan @legacy_io::@ret0[%c0] : !flow.dispatch.output + %0 = hal.interface.binding.subspan @legacy_io::@TENSOR_LHS[%c0] : !flow.dispatch.tensor + %1 = hal.interface.binding.subspan @legacy_io::@TENSOR_RHS[%c0] : !flow.dispatch.tensor + %2 = hal.interface.binding.subspan @legacy_io::@TENSOR_INIT[%c0] : !flow.dispatch.tensor + %3 = hal.interface.binding.subspan @legacy_io::@ret0[%c0] : !flow.dispatch.tensor %4 = hal.interface.workgroup.id[0] : index %5 = hal.interface.workgroup.id[1] : index scf.for %arg0 = %5 to %c2 step %c2 { scf.for %arg1 = %4 to %c4 step %c4 { - %6 = flow.dispatch.input.load %0, offsets = [%arg0, %c0], sizes = [%c1, %c3], strides = [%c1, %c1] : !flow.dispatch.input -> tensor<1x3xf32> - %7 = flow.dispatch.input.load %1, offsets = [%c0, %arg1], sizes = [%c3, %c1], strides = [%c1, %c1] : !flow.dispatch.input -> tensor<3x1xf32> + %6 = flow.dispatch.tensor.load %0, offsets = [%arg0, %c0], sizes = [%c1, %c3], strides = [%c1, %c1] : !flow.dispatch.tensor -> tensor<1x3xf32> + %7 = flow.dispatch.tensor.load %1, offsets = [%c0, %arg1], sizes = [%c3, %c1], strides = [%c1, %c1] : !flow.dispatch.tensor -> tensor<3x1xf32> %shape = linalg.init_tensor [1, 3] : tensor<1x3xf32> %8 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%6 : tensor<1x3xf32>) outs(%shape : tensor<1x3xf32>) { ^bb0(%arg2: f32, %s: f32): // no predecessors linalg.yield %arg2 : f32 } -> tensor<1x3xf32> - %9 = flow.dispatch.input.load %2, offsets = [%arg0, %arg1], sizes = [%c1, %c1], strides = [%c1, %c1] : !flow.dispatch.input -> tensor<1x1xf32> + %9 = flow.dispatch.tensor.load %2, offsets = [%arg0, %arg1], sizes = [%c1, %c1], strides = [%c1, %c1] : !flow.dispatch.tensor -> tensor<1x1xf32> %10 = linalg.matmul ins(%8, %7 : tensor<1x3xf32>, tensor<3x1xf32>) outs(%9 : tensor<1x1xf32>) -> tensor<1x1xf32> - flow.dispatch.output.store %10, %3, offsets = [%arg0, %arg1], sizes = [%c1, %c1], strides = [%c1, %c1] : tensor<1x1xf32> -> !flow.dispatch.output + flow.dispatch.tensor.store %10, %3, offsets = [%arg0, %arg1], sizes = [%c1, %c1], strides = [%c1, %c1] : tensor<1x1xf32> -> !flow.dispatch.tensor } } return @@ -115,17 +115,17 @@ func @tile_from_pointwise_outs() { %c4 = constant 4 : index %c1 = constant 1 : index %c3 = constant 3 : index - %0 = hal.interface.binding.subspan @legacy_io::@TENSOR_LHS[%c0] : !flow.dispatch.input - %1 = hal.interface.binding.subspan @legacy_io::@TENSOR_RHS[%c0] : !flow.dispatch.input - %2 = hal.interface.binding.subspan @legacy_io::@TENSOR_INIT[%c0] : !flow.dispatch.input - %3 = hal.interface.binding.subspan @legacy_io::@ret0[%c0] : !flow.dispatch.output + %0 = hal.interface.binding.subspan @legacy_io::@TENSOR_LHS[%c0] : !flow.dispatch.tensor + %1 = hal.interface.binding.subspan @legacy_io::@TENSOR_RHS[%c0] : !flow.dispatch.tensor + %2 = hal.interface.binding.subspan @legacy_io::@TENSOR_INIT[%c0] : !flow.dispatch.tensor + %3 = hal.interface.binding.subspan @legacy_io::@ret0[%c0] : !flow.dispatch.tensor %4 = hal.interface.workgroup.id[0] : index %5 = hal.interface.workgroup.id[1] : index scf.for %arg0 = %5 to %c2 step %c2 { scf.for %arg1 = %4 to %c4 step %c4 { - %6 = flow.dispatch.input.load %0, offsets = [%arg0, %c0], sizes = [%c1, %c3], strides = [%c1, %c1] : !flow.dispatch.input -> tensor<1x3xf32> - %7 = flow.dispatch.input.load %1, offsets = [%c0, %arg1], sizes = [%c3, %c1], strides = [%c1, %c1] : !flow.dispatch.input -> tensor<3x1xf32> - %8 = flow.dispatch.input.load %2, offsets = [%arg0, %arg1], sizes = [%c1, %c1], strides = [%c1, %c1] : !flow.dispatch.input -> tensor<1x1xf32> + %6 = flow.dispatch.tensor.load %0, offsets = [%arg0, %c0], sizes = [%c1, %c3], strides = [%c1, %c1] : !flow.dispatch.tensor -> tensor<1x3xf32> + %7 = flow.dispatch.tensor.load %1, offsets = [%c0, %arg1], sizes = [%c3, %c1], strides = [%c1, %c1] : !flow.dispatch.tensor -> tensor<3x1xf32> + %8 = flow.dispatch.tensor.load %2, offsets = [%arg0, %arg1], sizes = [%c1, %c1], strides = [%c1, %c1] : !flow.dispatch.tensor -> tensor<1x1xf32> %shape = linalg.init_tensor [1, 1] : tensor<1x1xf32> %9 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%8 : tensor<1x1xf32>) outs(%shape : tensor<1x1xf32>) { @@ -133,7 +133,7 @@ func @tile_from_pointwise_outs() { linalg.yield %arg2 : f32 } -> tensor<1x1xf32> %10 = linalg.matmul ins(%6, %7 : tensor<1x3xf32>, tensor<3x1xf32>) outs(%9 : tensor<1x1xf32>) -> tensor<1x1xf32> - flow.dispatch.output.store %10, %3, offsets = [%arg0, %arg1], sizes = [%c1, %c1], strides = [%c1, %c1] : tensor<1x1xf32> -> !flow.dispatch.output + flow.dispatch.tensor.store %10, %3, offsets = [%arg0, %arg1], sizes = [%c1, %c1], strides = [%c1, %c1] : tensor<1x1xf32> -> !flow.dispatch.tensor } } return @@ -167,10 +167,10 @@ hal.interface @legacy_io attributes {sym_visibility = "private"} { func @bufferize_dynamic() { %c0 = constant 0 : index %c1 = constant 1 : index - %0 = hal.interface.binding.subspan @legacy_io::@arg0[%c0] : !flow.dispatch.input - %1 = hal.interface.binding.subspan @legacy_io::@arg1[%c0] : !flow.dispatch.input - %2 = hal.interface.binding.subspan @legacy_io::@arg2[%c0] : !flow.dispatch.input - %3 = hal.interface.binding.subspan @legacy_io::@ret0[%c0] : !flow.dispatch.output + %0 = hal.interface.binding.subspan @legacy_io::@arg0[%c0] : !flow.dispatch.tensor + %1 = hal.interface.binding.subspan @legacy_io::@arg1[%c0] : !flow.dispatch.tensor + %2 = hal.interface.binding.subspan @legacy_io::@arg2[%c0] : !flow.dispatch.tensor + %3 = hal.interface.binding.subspan @legacy_io::@ret0[%c0] : !flow.dispatch.tensor %4 = hal.interface.load.constant offset = 0 : index %5 = hal.interface.load.constant offset = 1 : index %6 = hal.interface.load.constant offset = 2 : index @@ -180,13 +180,13 @@ func @bufferize_dynamic() { %10 = hal.interface.load.constant offset = 6 : index %11 = hal.interface.load.constant offset = 7 : index %12 = shapex.make_ranked_shape %4, %5 : (index, index) -> !shapex.ranked_shape<[?,?]> - %13 = flow.dispatch.tie_shape %0, %12 : (!flow.dispatch.input, !shapex.ranked_shape<[?,?]>) -> !flow.dispatch.input + %13 = flow.dispatch.tie_shape %0, %12 : (!flow.dispatch.tensor, !shapex.ranked_shape<[?,?]>) -> !flow.dispatch.tensor %14 = shapex.make_ranked_shape %6, %7 : (index, index) -> !shapex.ranked_shape<[?,?]> - %15 = flow.dispatch.tie_shape %1, %14 : (!flow.dispatch.input, !shapex.ranked_shape<[?,?]>) -> !flow.dispatch.input + %15 = flow.dispatch.tie_shape %1, %14 : (!flow.dispatch.tensor, !shapex.ranked_shape<[?,?]>) -> !flow.dispatch.tensor %16 = shapex.make_ranked_shape %8, %9 : (index, index) -> !shapex.ranked_shape<[?,?]> - %17 = flow.dispatch.tie_shape %2, %16 : (!flow.dispatch.input, !shapex.ranked_shape<[?,?]>) -> !flow.dispatch.input + %17 = flow.dispatch.tie_shape %2, %16 : (!flow.dispatch.tensor, !shapex.ranked_shape<[?,?]>) -> !flow.dispatch.tensor %18 = shapex.make_ranked_shape %10, %11 : (index, index) -> !shapex.ranked_shape<[?,?]> - %19 = flow.dispatch.tie_shape %3, %18 : (!flow.dispatch.output, !shapex.ranked_shape<[?,?]>) -> !flow.dispatch.output + %19 = flow.dispatch.tie_shape %3, %18 : (!flow.dispatch.tensor, !shapex.ranked_shape<[?,?]>) -> !flow.dispatch.tensor %workgroup_size_x = hal.interface.workgroup.size[0] : index %workgroup_size_y = hal.interface.workgroup.size[1] : index %workgroup_id_x = hal.interface.workgroup.id[0] : index @@ -200,14 +200,14 @@ func @bufferize_dynamic() { %23 = muli %workgroup_size_x, %workgroup_count_x : index scf.for %arg1 = %22 to %7 step %23 { %24 = affine.min affine_map<(d0)[s0, s1] -> (s1, -d0 + s0)>(%arg0)[%4, %workgroup_size_y] - %25 = flow.dispatch.input.load %13, offsets = [%arg0, %c0], sizes = [%24, %5], strides = [%c1, %c1] : !flow.dispatch.input -> tensor + %25 = flow.dispatch.tensor.load %13, offsets = [%arg0, %c0], sizes = [%24, %5], strides = [%c1, %c1] : !flow.dispatch.tensor -> tensor %26 = affine.min affine_map<(d0)[s0, s1] -> (s1, -d0 + s0)>(%arg1)[%7, %workgroup_size_x] - %27 = flow.dispatch.input.load %15, offsets = [%c0, %arg1], sizes = [%6, %26], strides = [%c1, %c1] : !flow.dispatch.input -> tensor + %27 = flow.dispatch.tensor.load %15, offsets = [%c0, %arg1], sizes = [%6, %26], strides = [%c1, %c1] : !flow.dispatch.tensor -> tensor %28 = affine.min affine_map<(d0)[s0, s1] -> (s0, -d0 + s1)>(%arg0)[%workgroup_size_y, %8] %29 = affine.min affine_map<(d0)[s0, s1] -> (s0, -d0 + s1)>(%arg1)[%workgroup_size_x, %9] - %30 = flow.dispatch.input.load %17, offsets = [%arg0, %arg1], sizes = [%28, %29], strides = [%c1, %c1] : !flow.dispatch.input -> tensor + %30 = flow.dispatch.tensor.load %17, offsets = [%arg0, %arg1], sizes = [%28, %29], strides = [%c1, %c1] : !flow.dispatch.tensor -> tensor %31 = linalg.matmul {__internal_linalg_transform__ = "workgroup"} ins(%25, %27 : tensor, tensor) outs(%30 : tensor) -> tensor - flow.dispatch.output.store %31, %19, offsets = [%arg0, %arg1], sizes = [%28, %29], strides = [%c1, %c1] : tensor -> !flow.dispatch.output + flow.dispatch.tensor.store %31, %19, offsets = [%arg0, %arg1], sizes = [%28, %29], strides = [%c1, %c1] : tensor -> !flow.dispatch.tensor } } return @@ -271,13 +271,13 @@ hal.interface @legacy_io attributes {sym_visibility = "private"} { // %c0 = constant 0 : index // %c2 = constant 2 : index // %c1 = constant 1 : index -// %0 = hal.interface.binding.subspan @legacy_io::@arg0[%c0] : !flow.dispatch.input<2x3xf32> -// %1 = hal.interface.binding.subspan @legacy_io::@arg1[%c0] : !flow.dispatch.input<3x4xf32> -// %2 = hal.interface.binding.subspan @legacy_io::@arg2[%c0] : !flow.dispatch.input<2x4xf32> -// %3 = hal.interface.binding.subspan @legacy_io::@ret0[%c0] : !flow.dispatch.output<2x4xf32> -// %4 = flow.dispatch.input.load %0, offsets = [%c0, %c0], sizes = [%c1, %c3], strides = [%c1, %c1] : !flow.dispatch.input<2x3xf32> -> tensor<2x3xf32> -// %5 = flow.dispatch.input.load %1, offsets = [%c0, %c0], sizes = [%c3, %c1], strides = [%c1, %c1] : !flow.dispatch.input<3x4xf32> -> tensor<3x1xf32> -// %6 = flow.dispatch.input.load %2, offsets = [%c0, %c0], sizes = [%c1, %c1], strides = [%c1, %c1] : !flow.dispatch.input<2x4xf32> -> tensor<2x1xf32> +// %0 = hal.interface.binding.subspan @legacy_io::@arg0[%c0] : !flow.dispatch.tensor +// %1 = hal.interface.binding.subspan @legacy_io::@arg1[%c0] : !flow.dispatch.tensor +// %2 = hal.interface.binding.subspan @legacy_io::@arg2[%c0] : !flow.dispatch.tensor +// %3 = hal.interface.binding.subspan @legacy_io::@ret0[%c0] : !flow.dispatch.tensor +// %4 = flow.dispatch.tensor.load %0, offsets = [%c0, %c0], sizes = [%c1, %c3], strides = [%c1, %c1] : !flow.dispatch.tensor -> tensor<2x3xf32> +// %5 = flow.dispatch.tensor.load %1, offsets = [%c0, %c0], sizes = [%c3, %c1], strides = [%c1, %c1] : !flow.dispatch.tensor -> tensor<3x1xf32> +// %6 = flow.dispatch.tensor.load %2, offsets = [%c0, %c0], sizes = [%c1, %c1], strides = [%c1, %c1] : !flow.dispatch.tensor -> tensor<2x1xf32> // %7 = vector.transfer_read %4[%c0, %c0], %cst {masked = [false, false]} : tensor<2x3xf32>, vector<1x1xf32> // %8 = vector.transfer_read %4[%c0, %c1], %cst {masked = [false, false]} : tensor<2x3xf32>, vector<1x1xf32> // %9 = vector.transfer_read %4[%c0, %c2], %cst {masked = [false, false]} : tensor<2x3xf32>, vector<1x1xf32> @@ -297,7 +297,7 @@ hal.interface @legacy_io attributes {sym_visibility = "private"} { // %23 = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %12, %15, %22 : vector<1x1xf32>, vector<1x1xf32> into vector<1x1xf32> // %24 = vector.transfer_write %20, %6[%c0, %c0] {masked = [false, false]} : vector<1x1xf32>, tensor<2x1xf32> // %25 = vector.transfer_write %23, %24[%c1, %c0] {masked = [false, false]} : vector<1x1xf32>, tensor<2x1xf32> -// flow.dispatch.output.store %25, %3, offsets = [%c0, %c0], sizes = [%c1, %c1], strides = [%c1, %c1] : tensor<2x1xf32> -> !flow.dispatch.output<2x4xf32> +// flow.dispatch.tensor.store %25, %3, offsets = [%c0, %c0], sizes = [%c1, %c1], strides = [%c1, %c1] : tensor<2x1xf32> -> !flow.dispatch.tensor // return // } // } @@ -310,11 +310,11 @@ func @reshape_simple() { %c3 = constant 3 : index %c4 = constant 4 : index %c12 = constant 12 : index - %0 = hal.interface.binding.subspan @legacy_io::@arg0[%c0] : !flow.dispatch.input<12xi32> - %1 = hal.interface.binding.subspan @legacy_io::@ret0[%c0] : !flow.dispatch.output<3x4xi32> - %2 = flow.dispatch.input.load %0, offsets = [%c0], sizes = [%c12], strides = [%c1] : !flow.dispatch.input<12xi32> -> tensor<12xi32> + %0 = hal.interface.binding.subspan @legacy_io::@arg0[%c0] : !flow.dispatch.tensor + %1 = hal.interface.binding.subspan @legacy_io::@ret0[%c0] : !flow.dispatch.tensor + %2 = flow.dispatch.tensor.load %0, offsets = [%c0], sizes = [%c12], strides = [%c1] : !flow.dispatch.tensor -> tensor<12xi32> %3 = linalg.tensor_reshape %2 [affine_map<(d0, d1) -> (d0, d1)>] : tensor<12xi32> into tensor<3x4xi32> - flow.dispatch.output.store %3, %1, offsets = [%c0, %c0], sizes = [%c3, %c4], strides = [%c1, %c1] : tensor<3x4xi32> -> !flow.dispatch.output<3x4xi32> + flow.dispatch.tensor.store %3, %1, offsets = [%c0, %c0], sizes = [%c3, %c4], strides = [%c1, %c1] : tensor<3x4xi32> -> !flow.dispatch.tensor return } hal.interface @legacy_io attributes {sym_visibility = "private"} { @@ -338,9 +338,9 @@ func @reshape_fused_source() { %c3 = constant 3 : index %c4 = constant 4 : index %c12 = constant 12 : index - %0 = hal.interface.binding.subspan @legacy_io::@arg0[%c0] : !flow.dispatch.input<12xi32> - %1 = hal.interface.binding.subspan @legacy_io::@ret0[%c0] : !flow.dispatch.output<3x4xi32> - %2 = flow.dispatch.input.load %0, offsets = [%c0], sizes = [%c12], strides = [%c1] : !flow.dispatch.input<12xi32> -> tensor<12xi32> + %0 = hal.interface.binding.subspan @legacy_io::@arg0[%c0] : !flow.dispatch.tensor + %1 = hal.interface.binding.subspan @legacy_io::@ret0[%c0] : !flow.dispatch.tensor + %2 = flow.dispatch.tensor.load %0, offsets = [%c0], sizes = [%c12], strides = [%c1] : !flow.dispatch.tensor -> tensor<12xi32> %3 = linalg.tensor_reshape %2 [affine_map<(d0, d1) -> (d0, d1)>] : tensor<12xi32> into tensor<3x4xi32> %4 = linalg.init_tensor [3, 4] : tensor<3x4xi32> %5 = linalg.generic { @@ -351,7 +351,7 @@ func @reshape_fused_source() { %6 = addi %arg0, %arg0 : i32 linalg.yield %6 : i32 } -> tensor<3x4xi32> - flow.dispatch.output.store %5, %1, offsets = [%c0, %c0], sizes = [%c3, %c4], strides = [%c1, %c1] : tensor<3x4xi32> -> !flow.dispatch.output<3x4xi32> + flow.dispatch.tensor.store %5, %1, offsets = [%c0, %c0], sizes = [%c3, %c4], strides = [%c1, %c1] : tensor<3x4xi32> -> !flow.dispatch.tensor return } hal.interface @legacy_io attributes {sym_visibility = "private"} { @@ -378,10 +378,10 @@ func @reshape_fused_source_and_copyout() { %c3 = constant 3 : index %c4 = constant 4 : index %c12 = constant 12 : index - %0 = hal.interface.binding.subspan @legacy_io::@arg0[%c0] : !flow.dispatch.input<12xi32> - %1 = hal.interface.binding.subspan @legacy_io::@ret0[%c0] : !flow.dispatch.output<3x4xi32> - %2 = hal.interface.binding.subspan @legacy_io::@ret1[%c0] : !flow.dispatch.output<3x4xi32> - %3 = flow.dispatch.input.load %0, offsets = [%c0], sizes = [%c12], strides = [%c1] : !flow.dispatch.input<12xi32> -> tensor<12xi32> + %0 = hal.interface.binding.subspan @legacy_io::@arg0[%c0] : !flow.dispatch.tensor + %1 = hal.interface.binding.subspan @legacy_io::@ret0[%c0] : !flow.dispatch.tensor + %2 = hal.interface.binding.subspan @legacy_io::@ret1[%c0] : !flow.dispatch.tensor + %3 = flow.dispatch.tensor.load %0, offsets = [%c0], sizes = [%c12], strides = [%c1] : !flow.dispatch.tensor -> tensor<12xi32> %4 = linalg.tensor_reshape %3 [affine_map<(d0, d1) -> (d0, d1)>] : tensor<12xi32> into tensor<3x4xi32> %5 = linalg.init_tensor [3, 4] : tensor<3x4xi32> %6 = linalg.generic { @@ -392,8 +392,8 @@ func @reshape_fused_source_and_copyout() { %7 = addi %arg0, %arg0 : i32 linalg.yield %7 : i32 } -> tensor<3x4xi32> - flow.dispatch.output.store %6, %1, offsets = [%c0, %c0], sizes = [%c3, %c4], strides = [%c1, %c1] : tensor<3x4xi32> -> !flow.dispatch.output<3x4xi32> - flow.dispatch.output.store %4, %2, offsets = [%c0, %c0], sizes = [%c3, %c4], strides = [%c1, %c1] : tensor<3x4xi32> -> !flow.dispatch.output<3x4xi32> + flow.dispatch.tensor.store %6, %1, offsets = [%c0, %c0], sizes = [%c3, %c4], strides = [%c1, %c1] : tensor<3x4xi32> -> !flow.dispatch.tensor + flow.dispatch.tensor.store %4, %2, offsets = [%c0, %c0], sizes = [%c3, %c4], strides = [%c1, %c1] : tensor<3x4xi32> -> !flow.dispatch.tensor return } hal.interface @legacy_io attributes {sym_visibility = "private"} { @@ -425,9 +425,9 @@ func @reshape_fused_target() { %c3 = constant 3 : index %c4 = constant 4 : index %c12 = constant 12 : index - %0 = hal.interface.binding.subspan @legacy_io::@arg0[%c0] : !flow.dispatch.input<3x4xi32> - %1 = hal.interface.binding.subspan @legacy_io::@ret0[%c0] : !flow.dispatch.output<12xi32> - %2 = flow.dispatch.input.load %0, offsets = [%c0, %c0], sizes = [%c3, %c4], strides = [%c1, %c1] : !flow.dispatch.input<3x4xi32> -> tensor<3x4xi32> + %0 = hal.interface.binding.subspan @legacy_io::@arg0[%c0] : !flow.dispatch.tensor + %1 = hal.interface.binding.subspan @legacy_io::@ret0[%c0] : !flow.dispatch.tensor + %2 = flow.dispatch.tensor.load %0, offsets = [%c0, %c0], sizes = [%c3, %c4], strides = [%c1, %c1] : !flow.dispatch.tensor -> tensor<3x4xi32> %3 = linalg.init_tensor [3, 4] : tensor<3x4xi32> %4 = linalg.generic { indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], @@ -438,7 +438,7 @@ func @reshape_fused_target() { linalg.yield %5 : i32 } -> tensor<3x4xi32> %5 = linalg.tensor_reshape %4 [affine_map<(d0, d1) -> (d0, d1)>] : tensor<3x4xi32> into tensor<12xi32> - flow.dispatch.output.store %5, %1, offsets = [%c0], sizes = [%c12], strides = [%c1] : tensor<12xi32> -> !flow.dispatch.output<12xi32> + flow.dispatch.tensor.store %5, %1, offsets = [%c0], sizes = [%c12], strides = [%c1] : tensor<12xi32> -> !flow.dispatch.tensor return } hal.interface @legacy_io attributes {sym_visibility = "private"} { @@ -467,10 +467,10 @@ func @dot_general_lowering() { %c0 = constant 0 : index %c2 = constant 2 : index %c1 = constant 1 : index - %0 = hal.interface.binding.subspan @legacy_io::@arg0[%c0] : !flow.dispatch.input<1x1x2xf32> - %1 = hal.interface.binding.subspan @legacy_io::@arg1[%c0] : !flow.dispatch.input<2x3xf32> - %2 = hal.interface.binding.subspan @legacy_io::@ret0[%c0] : !flow.dispatch.output<1x3xf32> - %3 = flow.dispatch.input.load %0 : !flow.dispatch.input<1x1x2xf32> -> tensor<1x1x2xf32> + %0 = hal.interface.binding.subspan @legacy_io::@arg0[%c0] : !flow.dispatch.tensor + %1 = hal.interface.binding.subspan @legacy_io::@arg1[%c0] : !flow.dispatch.tensor + %2 = hal.interface.binding.subspan @legacy_io::@ret0[%c0] : !flow.dispatch.tensor + %3 = flow.dispatch.tensor.load %0 : !flow.dispatch.tensor -> tensor<1x1x2xf32> %4 = linalg.tensor_reshape %3 [affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>] : tensor<1x1x2xf32> into tensor<1x2xf32> %workgroup_size_x = hal.interface.workgroup.size[0] : index %workgroup_size_y = hal.interface.workgroup.size[1] : index @@ -487,11 +487,11 @@ func @dot_general_lowering() { %9 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 1)>(%arg0)[%workgroup_size_y] %10 = subtensor %4[%arg0, 0] [%9, 2] [1, 1] : tensor<1x2xf32> to tensor %11 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 3)>(%arg1)[%workgroup_size_x] - %12 = flow.dispatch.input.load %1, offsets = [%c0, %arg1], sizes = [%c2, %11], strides = [%c1, %c1] : !flow.dispatch.input<2x3xf32> -> tensor<2x?xf32> + %12 = flow.dispatch.tensor.load %1, offsets = [%c0, %arg1], sizes = [%c2, %11], strides = [%c1, %c1] : !flow.dispatch.tensor -> tensor<2x?xf32> %13 = linalg.init_tensor [%9, %11] : tensor %14 = linalg.fill(%13, %cst) : tensor, f32 -> tensor %15 = linalg.matmul {__internal_linalg_transform__ = "workgroup"} ins(%10, %12 : tensor, tensor<2x?xf32>) outs(%14 : tensor) -> tensor - flow.dispatch.output.store %15, %2, offsets = [%arg0, %arg1], sizes = [%9, %11], strides = [%c1, %c1] : tensor -> !flow.dispatch.output<1x3xf32> + flow.dispatch.tensor.store %15, %2, offsets = [%arg0, %arg1], sizes = [%9, %11], strides = [%c1, %c1] : tensor -> !flow.dispatch.tensor } } return @@ -520,11 +520,11 @@ hal.interface @legacy_io attributes {sym_visibility = "private"} { func @slice_whole_stride_dispatch_0() { %c0 = constant 0 : index - %0 = hal.interface.binding.subspan @legacy_io::@arg0[%c0] : !flow.dispatch.input - %1 = hal.interface.binding.subspan @legacy_io::@ret0[%c0] : !flow.dispatch.output - %2 = flow.dispatch.input.load %0 : !flow.dispatch.input -> tensor + %0 = hal.interface.binding.subspan @legacy_io::@arg0[%c0] : !flow.dispatch.tensor + %1 = hal.interface.binding.subspan @legacy_io::@ret0[%c0] : !flow.dispatch.tensor + %2 = flow.dispatch.tensor.load %0 : !flow.dispatch.tensor -> tensor %3 = subtensor %2[1, 0] [1, 4] [1, 1] : tensor to tensor<1x4xi32> - flow.dispatch.output.store %3, %1 : tensor<1x4xi32> -> !flow.dispatch.output + flow.dispatch.tensor.store %3, %1 : tensor<1x4xi32> -> !flow.dispatch.tensor return } hal.interface @legacy_io attributes {sym_visibility = "private"} { @@ -542,15 +542,15 @@ hal.interface @legacy_io attributes {sym_visibility = "private"} { func @subtensor_insert() { %c0 = constant 0 : index %c1 = constant 1 : index - %0 = hal.interface.binding.subspan @legacy_io::@arg0[%c0] : !flow.dispatch.input - %1 = hal.interface.binding.subspan @legacy_io::@arg1[%c0] : !flow.dispatch.input - %2 = hal.interface.binding.subspan @legacy_io::@ret0[%c0] : !flow.dispatch.output - %3 = flow.dispatch.input.load %0 : !flow.dispatch.input -> tensor - %4 = flow.dispatch.input.load %1 : !flow.dispatch.input -> tensor + %0 = hal.interface.binding.subspan @legacy_io::@arg0[%c0] : !flow.dispatch.tensor + %1 = hal.interface.binding.subspan @legacy_io::@arg1[%c0] : !flow.dispatch.tensor + %2 = hal.interface.binding.subspan @legacy_io::@ret0[%c0] : !flow.dispatch.tensor + %3 = flow.dispatch.tensor.load %0 : !flow.dispatch.tensor -> tensor + %4 = flow.dispatch.tensor.load %1 : !flow.dispatch.tensor -> tensor %5 = dim %3, %c0 : tensor %6 = dim %3, %c1 : tensor %7 = subtensor_insert %3 into %4[3, 4] [%5, %6] [1, 1] : tensor into tensor - flow.dispatch.output.store %7, %2 : tensor -> !flow.dispatch.output + flow.dispatch.tensor.store %7, %2 : tensor -> !flow.dispatch.tensor return } hal.interface @legacy_io attributes {sym_visibility = "private"} { @@ -574,13 +574,13 @@ hal.interface @legacy_io attributes {sym_visibility = "private"} { func @tensor_extract() { %c0 = constant 0 : index - %0 = hal.interface.binding.subspan @legacy_io::@arg0[%c0] : !flow.dispatch.input - %1 = hal.interface.binding.subspan @legacy_io::@ret0[%c0] : !flow.dispatch.output<3x9xi32> + %0 = hal.interface.binding.subspan @legacy_io::@arg0[%c0] : !flow.dispatch.tensor + %1 = hal.interface.binding.subspan @legacy_io::@ret0[%c0] : !flow.dispatch.tensor %2 = linalg.init_tensor [3, 9] : tensor<3x9xi32> - %3 = flow.dispatch.input.load %0 : !flow.dispatch.input -> tensor + %3 = flow.dispatch.tensor.load %0 : !flow.dispatch.tensor -> tensor %4 = tensor.extract %3[] : tensor - %5 = linalg.fill(%2, %4) : tensor<3x9xi32>, i32 -> tensor<3x9xi32> - flow.dispatch.output.store %5, %1 : tensor<3x9xi32> -> !flow.dispatch.output<3x9xi32> + %5 = linalg.fill(%2, %4) : tensor<3x9xi32>, i32 -> tensor<3x9xi32> + flow.dispatch.tensor.store %5, %1 : tensor<3x9xi32> -> !flow.dispatch.tensor return } hal.interface @legacy_io attributes {sym_visibility = "private"} { diff --git a/iree/compiler/Conversion/LinalgToLLVM/test/linalg_vectorize.mlir b/iree/compiler/Conversion/LinalgToLLVM/test/linalg_vectorize.mlir index 0cccf12665ad..8512b0c5de88 100644 --- a/iree/compiler/Conversion/LinalgToLLVM/test/linalg_vectorize.mlir +++ b/iree/compiler/Conversion/LinalgToLLVM/test/linalg_vectorize.mlir @@ -4,9 +4,9 @@ // CHECK-DAG: %[[C0:.*]] = constant 0 : index // CHECK-DAG: %[[C1:.*]] = constant 1 : index // CHECK-DAG: %[[C2:.*]] = constant 2 : index -// CHECK: %[[I0:.*]] = flow.dispatch.input.load {{.*}} : !flow.dispatch.input<2x3xf32> -> tensor<2x3xf32> -// CHECK: %[[I1:.*]] = flow.dispatch.input.load {{.*}} : !flow.dispatch.input<3x4xf32> -> tensor<3x1xf32> -// CHECK: %[[I2:.*]] = flow.dispatch.input.load {{.*}} : !flow.dispatch.input<2x4xf32> -> tensor<2x1xf32> +// CHECK: %[[I0:.*]] = flow.dispatch.tensor.load {{.*}} : !flow.dispatch.tensor -> tensor<2x3xf32> +// CHECK: %[[I1:.*]] = flow.dispatch.tensor.load {{.*}} : !flow.dispatch.tensor -> tensor<3x1xf32> +// CHECK: %[[I2:.*]] = flow.dispatch.tensor.load {{.*}} : !flow.dispatch.tensor -> tensor<2x1xf32> // CHECK: %[[V0:.*]] = vector.transfer_read %[[I0]][%[[C0]], %[[C0]]], {{.*}} : tensor<2x3xf32>, vector<1x1xf32> // CHECK: %[[V1:.*]] = vector.transfer_read %[[I0]][%[[C0]], %[[C1]]], {{.*}} : tensor<2x3xf32>, vector<1x1xf32> // CHECK: %[[V2:.*]] = vector.transfer_read %[[I0]][%[[C0]], %[[C2]]], {{.*}} : tensor<2x3xf32>, vector<1x1xf32> @@ -26,21 +26,21 @@ // CHECK: %[[D5:.*]] = vector.contract {{.*}} %[[V5]], %[[V8]], %[[D4]] : vector<1x1xf32>, vector<1x1xf32> into vector<1x1xf32> // CHECK: %[[W0:.*]] = vector.transfer_write %[[D2]], %[[I2]][%[[C0]], %[[C0]]] {masked = [false, false]} : vector<1x1xf32>, tensor<2x1xf32> // CHECK: %[[W1:.*]] = vector.transfer_write %[[D5]], %[[W0]][%[[C1]], %[[C0]]] {masked = [false, false]} : vector<1x1xf32>, tensor<2x1xf32> -// CHECK: flow.dispatch.output.store %[[W1]] +// CHECK: flow.dispatch.tensor.store %[[W1]] func @tensor_dispatch_0() { %c0 = constant 0 : index %c3 = constant 3 : index %c1 = constant 1 : index %c2 = constant 1 : index - %0 = hal.interface.binding.subspan @legacy_io::@arg0[%c0] : !flow.dispatch.input<2x3xf32> - %1 = hal.interface.binding.subspan @legacy_io::@arg1[%c0] : !flow.dispatch.input<3x4xf32> - %2 = hal.interface.binding.subspan @legacy_io::@arg2[%c0] : !flow.dispatch.input<2x4xf32> - %3 = hal.interface.binding.subspan @legacy_io::@ret0[%c0] : !flow.dispatch.output<2x4xf32> - %4 = flow.dispatch.input.load %0, offsets = [%c0, %c0], sizes = [%c2, %c3], strides = [%c1, %c1] : !flow.dispatch.input<2x3xf32> -> tensor<2x3xf32> - %5 = flow.dispatch.input.load %1, offsets = [%c0, %c0], sizes = [%c3, %c1], strides = [%c1, %c1] : !flow.dispatch.input<3x4xf32> -> tensor<3x1xf32> - %6 = flow.dispatch.input.load %2, offsets = [%c0, %c0], sizes = [%c2, %c1], strides = [%c1, %c1] : !flow.dispatch.input<2x4xf32> -> tensor<2x1xf32> + %0 = hal.interface.binding.subspan @legacy_io::@arg0[%c0] : !flow.dispatch.tensor + %1 = hal.interface.binding.subspan @legacy_io::@arg1[%c0] : !flow.dispatch.tensor + %2 = hal.interface.binding.subspan @legacy_io::@arg2[%c0] : !flow.dispatch.tensor + %3 = hal.interface.binding.subspan @legacy_io::@ret0[%c0] : !flow.dispatch.tensor + %4 = flow.dispatch.tensor.load %0, offsets = [%c0, %c0], sizes = [%c2, %c3], strides = [%c1, %c1] : !flow.dispatch.tensor -> tensor<2x3xf32> + %5 = flow.dispatch.tensor.load %1, offsets = [%c0, %c0], sizes = [%c3, %c1], strides = [%c1, %c1] : !flow.dispatch.tensor -> tensor<3x1xf32> + %6 = flow.dispatch.tensor.load %2, offsets = [%c0, %c0], sizes = [%c2, %c1], strides = [%c1, %c1] : !flow.dispatch.tensor -> tensor<2x1xf32> %7 = linalg.matmul {__internal_linalg_transform__ = "workgroup"} ins(%4, %5 : tensor<2x3xf32>, tensor<3x1xf32>) outs(%6 : tensor<2x1xf32>) -> tensor<2x1xf32> - flow.dispatch.output.store %7, %3, offsets = [%c0, %c0], sizes = [%c2, %c1], strides = [%c1, %c1] : tensor<2x1xf32> -> !flow.dispatch.output<2x4xf32> + flow.dispatch.tensor.store %7, %3, offsets = [%c0, %c0], sizes = [%c2, %c1], strides = [%c1, %c1] : tensor<2x1xf32> -> !flow.dispatch.tensor return } diff --git a/iree/compiler/Conversion/LinalgToLLVM/test/materialize_launch_configuration.mlir b/iree/compiler/Conversion/LinalgToLLVM/test/materialize_launch_configuration.mlir index e2759c52151e..5c6018f3ef22 100644 --- a/iree/compiler/Conversion/LinalgToLLVM/test/materialize_launch_configuration.mlir +++ b/iree/compiler/Conversion/LinalgToLLVM/test/materialize_launch_configuration.mlir @@ -9,8 +9,8 @@ hal.executable @matmul_tensors attributes {sym_visibility = "private"} { hal.executable.target @llvm_aot, filter="dylib*" { hal.executable.entry_point @matmul_tensors attributes { interface = @legacy_io, ordinal = 0 : i32, - signature = (!flow.dispatch.input, !flow.dispatch.input, - !flow.dispatch.output) -> ()} + signature = (!flow.dispatch.tensor, !flow.dispatch.tensor, + !flow.dispatch.tensor) -> ()} module { func @matmul_tensors() { %c0 = constant 0 : index @@ -97,8 +97,8 @@ hal.executable @add attributes {sym_visibility = "private"} { hal.executable.target @llvm_aot, filter="dylib*" { hal.executable.entry_point @add attributes { interface = @legacy_io, ordinal = 0 : i32, - signature = (!flow.dispatch.input, !flow.dispatch.input, - !flow.dispatch.output) -> ()} + signature = (!flow.dispatch.tensor, !flow.dispatch.tensor, + !flow.dispatch.tensor) -> ()} module { func @add() { %c0 = constant 0 : index diff --git a/iree/compiler/Conversion/LinalgToLLVM/test/matmul_vectorization.mlir b/iree/compiler/Conversion/LinalgToLLVM/test/matmul_vectorization.mlir index 20979c8e0b86..5260df30969c 100644 --- a/iree/compiler/Conversion/LinalgToLLVM/test/matmul_vectorization.mlir +++ b/iree/compiler/Conversion/LinalgToLLVM/test/matmul_vectorization.mlir @@ -10,8 +10,8 @@ hal.executable @dynamic_matmul attributes {sym_visibility = "private"} { hal.executable.target @llvm_aot, filter="dylib*" { hal.executable.entry_point @matmul_128x128x128 attributes { interface = @legacy_io, ordinal = 0 : i32, - signature = (!flow.dispatch.input<128x128xf32>, !flow.dispatch.input<128x128xf32>, - !flow.dispatch.output<128x128xf32>) -> ()} + signature = (!flow.dispatch.tensor, !flow.dispatch.tensor, + !flow.dispatch.tensor) -> ()} module { func @matmul_128x128x128(%arg0 : memref<128x128xf32>, %arg1: memref<128x128xf32>, %arg2: memref<128x128xf32>) { linalg.matmul ins(%arg0, %arg1 : memref<128x128xf32>, memref<128x128xf32>) outs(%arg2 : memref<128x128xf32>) @@ -91,8 +91,8 @@ hal.executable @dynamic_matmul_i8_i8_i32 attributes {sym_visibility = "private"} hal.executable.target @llvm_aot, filter="dylib*" { hal.executable.entry_point @matmul_i8_i8_i32_128x128x128 attributes { interface = @legacy_io, ordinal = 0 : i32, - signature = (!flow.dispatch.input<128x128xi8>, !flow.dispatch.input<128x128xi8>, - !flow.dispatch.output<128x128xi32>) -> ()} + signature = (!flow.dispatch.tensor, !flow.dispatch.tensor, + !flow.dispatch.tensor) -> ()} module { func @matmul_i8_i8_i32_128x128x128(%arg0 : memref<128x128xi8>, %arg1: memref<128x128xi8>, %arg2: memref<128x128xi32>) { linalg.matmul_i8_i8_i32 ins(%arg0, %arg1 : memref<128x128xi8>, memref<128x128xi8>) outs(%arg2 : memref<128x128xi32>) diff --git a/iree/compiler/Conversion/LinalgToLLVM/test/tile_and_distribute.mlir b/iree/compiler/Conversion/LinalgToLLVM/test/tile_and_distribute.mlir index a430f6334824..ed84b770f19a 100644 --- a/iree/compiler/Conversion/LinalgToLLVM/test/tile_and_distribute.mlir +++ b/iree/compiler/Conversion/LinalgToLLVM/test/tile_and_distribute.mlir @@ -10,8 +10,8 @@ // hal.executable.target @llvm_aot, filter="dylib*" { // hal.executable.entry_point @dynamic_matmul attributes { // interface = @legacy_io, ordinal = 0 : i32, -// signature = (!flow.dispatch.input, !flow.dispatch.input, -// !flow.dispatch.output) -> ()} +// signature = (!flow.dispatch.tensor, !flow.dispatch.tensor, +// !flow.dispatch.tensor) -> ()} // module { // func @dynamic_matmul(%lhs: memref, %rhs: memref, %result: memref) { // linalg.matmul ins(%lhs, %rhs : memref, memref) outs(%result : memref) @@ -58,8 +58,8 @@ hal.executable @static_matmul attributes {sym_visibility = "private"} { hal.executable.target @llvm_aot, filter="dylib*" { hal.executable.entry_point @static_matmul attributes { interface = @legacy_io, ordinal = 0 : i32, - signature = (!flow.dispatch.input<16x4xf32>, !flow.dispatch.input<4x8xf32>, - !flow.dispatch.output<16x8xf32>) -> ()} + signature = (!flow.dispatch.tensor, !flow.dispatch.tensor, + !flow.dispatch.tensor) -> ()} module { func @static_matmul(%lhs: memref<16x4xf32>, %rhs: memref<4x8xf32>, %result: memref<16x8xf32>) { linalg.matmul ins(%lhs, %rhs : memref<16x4xf32>, memref<4x8xf32>) outs(%result : memref<16x8xf32>) diff --git a/iree/compiler/Conversion/LinalgToNVVM/test/pipeline_test.mlir b/iree/compiler/Conversion/LinalgToNVVM/test/pipeline_test.mlir index 170dee689660..9d6a7fd3a645 100644 --- a/iree/compiler/Conversion/LinalgToNVVM/test/pipeline_test.mlir +++ b/iree/compiler/Conversion/LinalgToNVVM/test/pipeline_test.mlir @@ -1,6 +1,6 @@ // RUN: iree-opt -pass-pipeline="hal.executable(hal.executable.target(iree-codegen-hlo-to-nvvm-pipeline))" %s | IreeFileCheck %s -// Verify that a simple element wise op gets lowered succefully all the way to +// Verify that a simple element wise op gets lowered succefully all the way to // nvvm/llvm dialect. hal.executable @simpleMath_ex_dispatch_0 { @@ -9,22 +9,22 @@ hal.executable @simpleMath_ex_dispatch_0 { hal.interface.binding @ret0, set=0, binding=1, type="StorageBuffer", access="Write|Discard" } hal.executable.target @cuda, filter="cuda" { - hal.executable.entry_point @add_dispatch_0 attributes {interface = @legacy_io, ordinal = 0 : i32, signature = (!flow.dispatch.input<16xf32>, !flow.dispatch.input<16xf32>, !flow.dispatch.output<16xf32>) -> ()} + hal.executable.entry_point @add_dispatch_0 attributes {interface = @legacy_io, ordinal = 0 : i32, signature = (!flow.dispatch.tensor, !flow.dispatch.tensor, !flow.dispatch.tensor) -> ()} module { func @add_dispatch_0() { %c0 = constant 0 : index - %0 = hal.interface.binding.subspan @legacy_io::@arg0[%c0] : !flow.dispatch.input<16xf32> - %1 = hal.interface.binding.subspan @legacy_io::@arg1[%c0] : !flow.dispatch.input<16xf32> - %2 = hal.interface.binding.subspan @legacy_io::@ret0[%c0] : !flow.dispatch.output<16xf32> + %0 = hal.interface.binding.subspan @legacy_io::@arg0[%c0] : !flow.dispatch.tensor + %1 = hal.interface.binding.subspan @legacy_io::@arg1[%c0] : !flow.dispatch.tensor + %2 = hal.interface.binding.subspan @legacy_io::@ret0[%c0] : !flow.dispatch.tensor %3 = linalg.init_tensor [16] : tensor<16xf32> - %4 = flow.dispatch.input.load %0 : !flow.dispatch.input<16xf32> -> tensor<16xf32> - %5 = flow.dispatch.input.load %1 : !flow.dispatch.input<16xf32> -> tensor<16xf32> + %4 = flow.dispatch.tensor.load %0 : !flow.dispatch.tensor -> tensor<16xf32> + %5 = flow.dispatch.tensor.load %1 : !flow.dispatch.tensor -> tensor<16xf32> %6 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%4, %5 : tensor<16xf32>, tensor<16xf32>) outs(%3 : tensor<16xf32>) { ^bb0(%arg0: f32, %arg1: f32, %arg2: f32): // no predecessors %7 = addf %arg0, %arg1 : f32 linalg.yield %7 : f32 } -> tensor<16xf32> - flow.dispatch.output.store %6, %2 : tensor<16xf32> -> !flow.dispatch.output<16xf32> + flow.dispatch.tensor.store %6, %2 : tensor<16xf32> -> !flow.dispatch.tensor return } hal.interface @legacy_io attributes {sym_visibility = "private"} { diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/batch_matmul_vectorization.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/batch_matmul_vectorization.mlir index c597271a70c2..dd66e0b4c8cd 100644 --- a/iree/compiler/Conversion/LinalgToSPIRV/test/batch_matmul_vectorization.mlir +++ b/iree/compiler/Conversion/LinalgToSPIRV/test/batch_matmul_vectorization.mlir @@ -9,8 +9,8 @@ hal.executable @batch_matmul_static_shape attributes {sym_visibility = "private" hal.executable.target @vulkan, filter="dylib*" { hal.executable.entry_point @batch_matmul_static_shape attributes { interface = @legacy_io, ordinal = 0 : i32, - signature = (!flow.dispatch.input, !flow.dispatch.input, - !flow.dispatch.output) -> ()} + signature = (!flow.dispatch.tensor, !flow.dispatch.tensor, + !flow.dispatch.tensor) -> ()} module attributes { spv.target_env = #spv.target_env<#spv.vce, !flow.dispatch.input, - !flow.dispatch.output) -> ()} + signature = (!flow.dispatch.tensor, !flow.dispatch.tensor, + !flow.dispatch.tensor) -> ()} module attributes { spv.target_env = #spv.target_env<#spv.vce, !flow.dispatch.input<3x3x16x32xf32>, !flow.dispatch.output<1x112x112x32xf32>) -> ()} + signature = (!flow.dispatch.tensor, !flow.dispatch.tensor, !flow.dispatch.tensor) -> ()} module attributes {spv.target_env = #spv.target_env<#spv.vce, ARM:IntegratedGPU, {}>} { func @conv2d_static_shape() { %cst = constant 0.000000e+00 : f32 @@ -120,7 +120,7 @@ hal.executable @matmul_dynamic_shape attributes {sym_visibility = "private"} { hal.executable.target @vulkan_spirv, filter="vulkan*" { hal.executable.entry_point @matmul_dynamic_shape attributes { interface = @legacy_io, ordinal = 0 : i32, - signature = (!flow.dispatch.input<1x225x225x16xf32>, !flow.dispatch.input<3x3x16x32xf32>, !flow.dispatch.output<1x112x112x32xf32>) -> ()} + signature = (!flow.dispatch.tensor, !flow.dispatch.tensor, !flow.dispatch.tensor) -> ()} module attributes {spv.target_env = #spv.target_env<#spv.vce, ARM:IntegratedGPU, {}>} { func @matmul_dynamic_shape() { %cst = constant 0.000000e+00 : f32 diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/convert_to_gpu.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/convert_to_gpu.mlir index 15e1e9281c89..141435794077 100644 --- a/iree/compiler/Conversion/LinalgToSPIRV/test/convert_to_gpu.mlir +++ b/iree/compiler/Conversion/LinalgToSPIRV/test/convert_to_gpu.mlir @@ -1,4 +1,4 @@ -// RUN: iree-opt -split-input-file -pass-pipeline="hal.executable(hal.executable.target(iree-codegen-convert-to-gpu))" -canonicalize -cse %s | IreeFileCheck %s +// RUN: iree-opt -split-input-file -pass-pipeline='hal.executable(hal.executable.target(iree-codegen-convert-to-gpu))' -canonicalize -cse %s | IreeFileCheck %s // TODO(GH-4901): Enable this test when linalg on tensors becomes default. // #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> @@ -11,8 +11,8 @@ // hal.executable.target @vulkan, filter="vulkan*" { // hal.executable.entry_point @parallel_4D attributes { // interface = @legacy_io, ordinal = 0 : i32, -// signature = (!flow.dispatch.input, !flow.dispatch.input, -// !flow.dispatch.output) -> ()} +// signature = (!flow.dispatch.tensor, !flow.dispatch.tensor, +// !flow.dispatch.tensor) -> ()} // module attributes { // spv.target_env = // #spv.target_env<#spv.vce, !flow.dispatch.input, - !flow.dispatch.output) -> ()} + signature = (!flow.dispatch.tensor, !flow.dispatch.tensor, + !flow.dispatch.tensor) -> ()} module attributes { spv.target_env = #spv.target_env<#spv.vce, !flow.dispatch.input, - !flow.dispatch.output) -> ()} + signature = (!flow.dispatch.tensor, !flow.dispatch.tensor, + !flow.dispatch.tensor) -> ()} module attributes { spv.target_env = #spv.target_env<#spv.vce, !flow.dispatch.input, - !flow.dispatch.output<40xf32>) -> ()} + signature = (!flow.dispatch.tensor, !flow.dispatch.tensor, + !flow.dispatch.tensor) -> ()} module { func @reduce_sum() { %arg0 = iree.placeholder for "interace buffer" @@ -296,8 +296,8 @@ hal.executable @matmul attributes {sym_visibility = "private"} { hal.executable.target @vulkan, filter="vulkan*" { hal.executable.entry_point @matmul attributes { interface = @legacy_io, ordinal = 0 : i32, - signature = (!flow.dispatch.input, !flow.dispatch.input, - !flow.dispatch.output) -> ()} + signature = (!flow.dispatch.tensor, !flow.dispatch.tensor, + !flow.dispatch.tensor) -> ()} module attributes { spv.target_env = #spv.target_env<#spv.vce, @@ -427,8 +427,8 @@ hal.executable @conv_no_padding attributes {sym_visibility = "private"} { hal.executable.target @vulkan, filter="vulkan*" { hal.executable.entry_point @conv_no_padding attributes { interface = @legacy_io, ordinal = 0 : i32, - signature = (!flow.dispatch.input, !flow.dispatch.input, - !flow.dispatch.output) -> ()} + signature = (!flow.dispatch.tensor, !flow.dispatch.tensor, + !flow.dispatch.tensor) -> ()} module attributes { spv.target_env = #spv.target_env<#spv.vce, @@ -603,7 +603,7 @@ module { hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard" } hal.executable.target @vulkan, filter="vulkan*" { - hal.executable.entry_point @pooling_nhwc_max attributes {interface = @legacy_io, ordinal = 0 : i32, signature = (!flow.dispatch.input<2x16x16x6xf32>, !flow.dispatch.input<1x3x4x2xf32>, !flow.dispatch.output<2x14x13x5xf32>) -> ()} { + hal.executable.entry_point @pooling_nhwc_max attributes {interface = @legacy_io, ordinal = 0 : i32, signature = (!flow.dispatch.tensor, !flow.dispatch.tensor, !flow.dispatch.tensor) -> ()} { ^bb0(%arg0: index, %arg1: index, %arg2: index): // no predecessors %c4 = constant 4 : index %c1 = constant 1 : index diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/linalg_tile_and_fuse.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/linalg_tile_and_fuse.mlir index f5a2b68858ba..1c0e06490836 100644 --- a/iree/compiler/Conversion/LinalgToSPIRV/test/linalg_tile_and_fuse.mlir +++ b/iree/compiler/Conversion/LinalgToSPIRV/test/linalg_tile_and_fuse.mlir @@ -1,4 +1,4 @@ -// RUN: iree-opt -split-input-file -pass-pipeline="hal.executable(hal.executable.target(iree-codegen-spirv-linalg-tile-and-distribute,iree-codegen-linalg-tile-and-fuse))" -iree-spirv-enable-vectorization -canonicalize -cse %s | IreeFileCheck %s +// RUN: iree-opt -split-input-file -pass-pipeline='hal.executable(hal.executable.target(iree-codegen-spirv-linalg-tile-and-distribute,iree-codegen-linalg-tile-and-fuse))' -iree-spirv-enable-vectorization -canonicalize -cse %s | IreeFileCheck %s // TODO(GH-4901): Convert these tests back to use dynamic shapes when linalg on tensors becomes default. hal.executable @conv_no_padding attributes {sym_visibility = "private"} { @@ -10,8 +10,8 @@ hal.executable @conv_no_padding attributes {sym_visibility = "private"} { hal.executable.target @vulkan, filter="dylib*" { hal.executable.entry_point @conv_no_padding attributes { interface = @legacy_io, ordinal = 0 : i32, - signature = (!flow.dispatch.input<3x4x6x14xf32>, !flow.dispatch.input<2x16x16x6xf32>, - !flow.dispatch.output<2x13x11x14xf32>) -> ()} + signature = (!flow.dispatch.tensor, !flow.dispatch.tensor, + !flow.dispatch.tensor) -> ()} module attributes { spv.target_env = #spv.target_env<#spv.vce, !flow.dispatch.input<50x75xf32>, - !flow.dispatch.output<25x75xf32>) -> ()} + signature = (!flow.dispatch.tensor, !flow.dispatch.tensor, + !flow.dispatch.tensor) -> ()} module attributes { spv.target_env = #spv.target_env<#spv.vce, !flow.dispatch.input<1x3x4x2xf32>, - !flow.dispatch.output<2x14x13x5xf32>) -> ()} + signature = (!flow.dispatch.tensor, !flow.dispatch.tensor, + !flow.dispatch.tensor) -> ()} module attributes { spv.target_env = #spv.target_env<#spv.vce, !flow.dispatch.input<50x75xf32>, - !flow.dispatch.output<25x75xf32>) -> ()} + signature = (!flow.dispatch.tensor, !flow.dispatch.tensor, + !flow.dispatch.tensor) -> ()} module attributes { spv.target_env = #spv.target_env<#spv.vce, !flow.dispatch.input<2x16x16x6xf32>, - !flow.dispatch.output<2x13x11x14xf32>) -> ()} + signature = (!flow.dispatch.tensor, !flow.dispatch.tensor, + !flow.dispatch.tensor) -> ()} module attributes { spv.target_env = #spv.target_env<#spv.vce, !flow.dispatch.input<50x75xf32>, - !flow.dispatch.output<25x75xf32>) -> ()} + signature = (!flow.dispatch.tensor, !flow.dispatch.tensor, + !flow.dispatch.tensor) -> ()} module attributes { spv.target_env = #spv.target_env<#spv.vce, !flow.dispatch.input, - !flow.dispatch.output) -> ()} + signature = (!flow.dispatch.tensor, !flow.dispatch.tensor, + !flow.dispatch.tensor) -> ()} module attributes {spv.target_env = #spv.target_env<#spv.vce, SwiftShader:CPU, {cooperative_matrix_properties_nv = [], max_compute_shared_memory_size = 16384 : i32, max_compute_workgroup_invocations = 128 : i32, max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>, subgroup_size = 4 : i32}>} { func @matmul_tensors() { %c0 = constant 0 : index diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/materialize_launch_configuration2.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/materialize_launch_configuration2.mlir index 30b13fab2dee..f988e5ade311 100644 --- a/iree/compiler/Conversion/LinalgToSPIRV/test/materialize_launch_configuration2.mlir +++ b/iree/compiler/Conversion/LinalgToSPIRV/test/materialize_launch_configuration2.mlir @@ -9,8 +9,8 @@ hal.executable @add attributes {sym_visibility = "private"} { hal.executable.target @vulkan_spirv, filter="vulkan*" { hal.executable.entry_point @add attributes { interface = @legacy_io, ordinal = 0 : i32, - signature = (!flow.dispatch.input, !flow.dispatch.input, - !flow.dispatch.output) -> ()} + signature = (!flow.dispatch.tensor, !flow.dispatch.tensor, + !flow.dispatch.tensor) -> ()} module attributes {spv.target_env = #spv.target_env<#spv.vce, SwiftShader:CPU, {cooperative_matrix_properties_nv = [], max_compute_shared_memory_size = 16384 : i32, max_compute_workgroup_invocations = 128 : i32, max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>, subgroup_size = 4 : i32}>} { func @add() { %c0 = constant 0 : index diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/matmul_fused_vectorization.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/matmul_fused_vectorization.mlir index b07f50ad08ff..19e148b12277 100644 --- a/iree/compiler/Conversion/LinalgToSPIRV/test/matmul_fused_vectorization.mlir +++ b/iree/compiler/Conversion/LinalgToSPIRV/test/matmul_fused_vectorization.mlir @@ -9,8 +9,8 @@ hal.executable @matmul_static_shape attributes {sym_visibility = "private"} { hal.executable.target @vulkan, filter="dylib*" { hal.executable.entry_point @matmul_static_shape attributes { interface = @legacy_io, ordinal = 0 : i32, - signature = (!flow.dispatch.input, !flow.dispatch.input, - !flow.dispatch.output) -> ()} + signature = (!flow.dispatch.tensor, !flow.dispatch.tensor, + !flow.dispatch.tensor) -> ()} module attributes { spv.target_env = #spv.target_env<#spv.vce, !flow.dispatch.input, - !flow.dispatch.output) -> ()} + signature = (!flow.dispatch.tensor, !flow.dispatch.tensor, + !flow.dispatch.tensor) -> ()} module attributes { spv.target_env = #spv.target_env<#spv.vce, !flow.dispatch.input, - !flow.dispatch.output) -> ()} + signature = (!flow.dispatch.tensor, !flow.dispatch.tensor, + !flow.dispatch.tensor) -> ()} module attributes { spv.target_env = #spv.target_env<#spv.vce, !flow.dispatch.input, - !flow.dispatch.output) -> ()} + signature = (!flow.dispatch.tensor, !flow.dispatch.tensor, + !flow.dispatch.tensor) -> ()} module attributes { spv.target_env = #spv.target_env<#spv.vce, !flow.dispatch.input, - !flow.dispatch.output) -> ()} + signature = (!flow.dispatch.tensor, !flow.dispatch.tensor, + !flow.dispatch.tensor) -> ()} module attributes { spv.target_env = #spv.target_env<#spv.vce %4 = alloc() : memref<1024x256xf32> linalg.fill(%4, %cst) : memref<1024x256xf32>, f32 - linalg.matmul ins(%1, %2 : memref<1024x512xf32>, memref<512x256xf32>) + linalg.matmul ins(%1, %2 : memref<1024x512xf32>, memref<512x256xf32>) outs(%4 : memref<1024x256xf32>) - linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, + linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], - iterator_types = ["parallel", "parallel"]} - ins(%4, %3 : memref<1024x256xf32>, memref<1024x256xf32>) + iterator_types = ["parallel", "parallel"]} + ins(%4, %3 : memref<1024x256xf32>, memref<1024x256xf32>) outs(%0 : memref<1024x256xf32>) { ^bb0(%arg0: f32, %arg1: f32, %arg2: f32): // no predecessors %5 = addf %arg0, %arg1 : f32 diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/pipeline_test_cooperative_mat.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/pipeline_test_cooperative_mat.mlir index 2e3382f34090..d928cc6e3828 100644 --- a/iree/compiler/Conversion/LinalgToSPIRV/test/pipeline_test_cooperative_mat.mlir +++ b/iree/compiler/Conversion/LinalgToSPIRV/test/pipeline_test_cooperative_mat.mlir @@ -10,8 +10,8 @@ hal.executable @matmul_static_shape attributes {sym_visibility = "private"} { hal.executable.target @vulkan, filter="dylib*" { hal.executable.entry_point @matmul_static_shape attributes { interface = @legacy_io, ordinal = 0 : i32, - signature = (!flow.dispatch.input, !flow.dispatch.input, - !flow.dispatch.output) -> ()} + signature = (!flow.dispatch.tensor, !flow.dispatch.tensor, + !flow.dispatch.tensor) -> ()} module attributes { spv.target_env = #spv.target_env<#spv.vce, !flow.dispatch.input<3x512x1xf32>, - !flow.dispatch.output) -> ()} + signature = (!flow.dispatch.tensor, !flow.dispatch.tensor, + !flow.dispatch.tensor) -> ()} module { // CHECK: func @kernel_fusable_fill_conv1d_ops // CHECK: linalg.fill @@ -55,8 +55,8 @@ hal.executable @kernel_fusable_fill_conv2d_ops attributes {sym_visiblity = "priv hal.executable.target @vulkan, filter="vulkan*" { hal.executable.entry_point @kernel_fusable_fill_conv2d_ops attributes { interface = @legacy_io, ordinal = 0 : i32, - signature = (!flow.dispatch.input, !flow.dispatch.input<3x3x512x1xf32>, - !flow.dispatch.output) -> ()} + signature = (!flow.dispatch.tensor, !flow.dispatch.tensor, + !flow.dispatch.tensor) -> ()} module { // CHECK: func @kernel_fusable_fill_conv2d_ops // CHECK: linalg.fill @@ -102,8 +102,8 @@ hal.executable @kernel_fusable_fill_conv3d_ops attributes {sym_visiblity = "priv hal.executable.target @vulkan, filter="vulkan*" { hal.executable.entry_point @kernel_fusable_fill_conv3d_ops attributes { interface = @legacy_io, ordinal = 0 : i32, - signature = (!flow.dispatch.input, !flow.dispatch.input<3x3x3x512x1xf32>, - !flow.dispatch.output) -> ()} + signature = (!flow.dispatch.tensor, !flow.dispatch.tensor, + !flow.dispatch.tensor) -> ()} module { // CHECK: func @kernel_fusable_fill_conv3d_ops // CHECK: linalg.fill @@ -149,8 +149,8 @@ hal.executable @kernel_fusable_fill_matmul_ops attributes {sym_visiblity = "priv hal.executable.target @vulkan, filter="vulkan*" { hal.executable.entry_point @kernel_fusable_fill_matmul_ops attributes { interface = @legacy_io, ordinal = 0 : i32, - signature = (!flow.dispatch.input, !flow.dispatch.input<512x?xf32>, - !flow.dispatch.output) -> ()} + signature = (!flow.dispatch.tensor, !flow.dispatch.tensor, + !flow.dispatch.tensor) -> ()} module { // CHECK: func @kernel_fusable_fill_matmul_ops // CHECK: linalg.fill @@ -196,8 +196,8 @@ hal.executable @kernel_fusable_pooling attributes {sym_visiblity = "private"} { hal.executable.target @vulkan, filter="vulkan*" { hal.executable.entry_point @kernel_fusable_pooling attributes { interface = @legacy_io, ordinal = 0 : i32, - signature = (!flow.dispatch.input, !flow.dispatch.input, - !flow.dispatch.output) -> ()} + signature = (!flow.dispatch.tensor, !flow.dispatch.tensor, + !flow.dispatch.tensor) -> ()} module { // CHECK: func @kernel_fusable_pooling() // CHECK: linalg.fill @@ -236,8 +236,8 @@ hal.executable @kernel attributes {sym_visiblity = "private"} { hal.executable.target @vulkan, filter="vulkan*" { hal.executable.entry_point @kernel attributes { interface = @legacy_io, ordinal = 0 : i32, - signature = (!flow.dispatch.input, !flow.dispatch.input<3x3x512x1xf32>, - !flow.dispatch.output) -> ()} + signature = (!flow.dispatch.tensor, !flow.dispatch.tensor, + !flow.dispatch.tensor) -> ()} // CHECK: hal.executable.entry_point @kernel_dispatch_0 // CHECK: hal.executable.entry_point @kernel_dispatch_1 // CHECK: module attributes {hal.entry_point_schedule = [@kernel_dispatch_0, @kernel_dispatch_1]} @@ -303,8 +303,8 @@ hal.executable @kernel attributes {sym_visiblity = "private"} { hal.executable.target @vulkan, filter="vulkan*" { hal.executable.entry_point @kernel attributes { interface = @legacy_io, ordinal = 0 : i32, - signature = (!flow.dispatch.input, !flow.dispatch.input<3x3x512x1xf32>, - !flow.dispatch.output) -> ()} + signature = (!flow.dispatch.tensor, !flow.dispatch.tensor, + !flow.dispatch.tensor) -> ()} // CHECK: hal.executable.entry_point @kernel_dispatch_0 // CHECK: hal.executable.entry_point @kernel_dispatch_1 // CHECK: hal.executable.entry_point @kernel_dispatch_2 @@ -385,8 +385,8 @@ hal.executable @kernel attributes {sym_visiblity = "private"} { hal.executable.target @vulkan, filter="vulkan*" { hal.executable.entry_point @kernel attributes { interface = @legacy_io, ordinal = 0 : i32, - signature = (!flow.dispatch.input<1x3x3x512xf32>, !flow.dispatch.input<3x3x512x1xf32>, - !flow.dispatch.output<1x1x1x512xf32>) -> ()} + signature = (!flow.dispatch.tensor, !flow.dispatch.tensor, + !flow.dispatch.tensor) -> ()} // CHECK-NOT: hal.entry_point_schedule module { // CHECK-LABEL: @kernel() @@ -427,8 +427,8 @@ hal.executable @kernel attributes {sym_visiblity = "private"} { hal.executable.target @vulkan, filter="vulkan*" { hal.executable.entry_point @kernel attributes { interface = @legacy_io, ordinal = 0 : i32, - signature = (!flow.dispatch.input, !flow.dispatch.input<3x512x1xf32>, - !flow.dispatch.output) -> ()} + signature = (!flow.dispatch.tensor, !flow.dispatch.tensor, + !flow.dispatch.tensor) -> ()} module { // expected-error @+1 {{cannot separate Linalg/Parallel ops into multiple kernels}} func @kernel() { @@ -465,7 +465,7 @@ hal.executable @subview_interleaved attributes {sym_visiblity = "private"} { hal.executable.target @vulkan, filter="vulkan*" { hal.executable.entry_point @subview_interleaved attributes { interface = @legacy_io, ordinal = 0 : i32, - signature = (!flow.dispatch.input<18x12xf32>, !flow.dispatch.output<12x4xf32>) -> ()} + signature = (!flow.dispatch.tensor, !flow.dispatch.tensor) -> ()} module { func @subview_interleaved() { %cst = constant 0.000000e+00 : f32 @@ -516,8 +516,8 @@ hal.executable @reshape_interleaved attributes {sym_visiblity = "private"} { hal.executable.target @vulkan, filter="vulkan*" { hal.executable.entry_point @reshape_interleaved attributes { interface = @legacy_io, ordinal = 0 : i32, - signature = (!flow.dispatch.input<2x4xf32>, !flow.dispatch.output<1x2x4xf32>, - !flow.dispatch.output<2x4xf32>) -> ()} + signature = (!flow.dispatch.tensor, !flow.dispatch.tensor, + !flow.dispatch.tensor) -> ()} module { func @reshape_interleaved() { %0 = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0} : memref<2x4xf32> @@ -576,8 +576,8 @@ hal.executable @predict_ex_dispatch_0 attributes {sym_visiblity = "private"} { hal.executable.target @vulkan, filter="vulkan*" { hal.executable.entry_point @predict_ex_dispatch_0 attributes { interface = @legacy_io, ordinal = 0 : i32, - signature = (!flow.dispatch.input<1x512x1xf32>, !flow.dispatch.input<4x8x16xf32>, - !flow.dispatch.output<4x8x16xf32>, !flow.dispatch.output<4x8x16xf32>) -> ()} + signature = (!flow.dispatch.tensor, !flow.dispatch.tensor, + !flow.dispatch.tensor, !flow.dispatch.tensor) -> ()} module { func @predict_ex_dispatch_0() { %0 = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0} : memref<1x512x1xf32> @@ -637,8 +637,8 @@ hal.executable @kernel_fusable_fill_matmul_generic_ops attributes {sym_visiblity hal.executable.target @vulkan, filter="vulkan*" { hal.executable.entry_point @kernel_fusable_fill_matmul_generic_ops attributes { interface = @legacy_io, ordinal = 0 : i32, - signature = (!flow.dispatch.input, !flow.dispatch.input<512x?xf32>, - !flow.dispatch.input, !flow.dispatch.output) -> ()} + signature = (!flow.dispatch.tensor, !flow.dispatch.tensor, + !flow.dispatch.tensor, !flow.dispatch.tensor) -> ()} module { // CHECK: func @kernel_fusable_fill_matmul_generic_ops // CHECK: linalg.fill diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/tile_and_vectorize_conv.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/tile_and_vectorize_conv.mlir index 165cd52394b8..18c360113a14 100644 --- a/iree/compiler/Conversion/LinalgToSPIRV/test/tile_and_vectorize_conv.mlir +++ b/iree/compiler/Conversion/LinalgToSPIRV/test/tile_and_vectorize_conv.mlir @@ -9,7 +9,7 @@ hal.executable @conv_static_shape_f32 attributes {sym_visibility = "private"} { hal.executable.target @vulkan_spirv, filter="vulkan*" { hal.executable.entry_point @conv_static_shape_f32 attributes { interface = @legacy_io, ordinal = 0 : i32, - signature = (!flow.dispatch.input<1x225x225x16xf32>, !flow.dispatch.input<3x3x16x32xf32>, !flow.dispatch.output<1x112x112x32xf32>) -> ()} + signature = (!flow.dispatch.tensor, !flow.dispatch.tensor, !flow.dispatch.tensor) -> ()} module attributes {spv.target_env = #spv.target_env<#spv.vce, ARM:IntegratedGPU, {}>} { func @conv_static_shape_f32() { %cst = constant 0.000000e+00 : f32 @@ -97,7 +97,7 @@ hal.executable @depthwise_conv_static_shape_f32 attributes {sym_visibility = "pr hal.executable.target @vulkan_spirv, filter="vulkan*" { hal.executable.entry_point @depthwise_conv_static_shape_f32 attributes { interface = @legacy_io, ordinal = 0 : i32, - signature = (!flow.dispatch.input<1x225x225x16xf32>, !flow.dispatch.input<3x3x16x32xf32>, !flow.dispatch.output<1x112x112x32xf32>) -> ()} + signature = (!flow.dispatch.tensor, !flow.dispatch.tensor, !flow.dispatch.tensor) -> ()} module attributes {spv.target_env = #spv.target_env<#spv.vce, ARM:IntegratedGPU, {}>} { func @depthwise_conv_static_shape_f32() { %cst = constant 0.000000e+00 : f32 diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/tile_and_vectorize_matmul.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/tile_and_vectorize_matmul.mlir index 74d43514e568..3f17d95e8809 100644 --- a/iree/compiler/Conversion/LinalgToSPIRV/test/tile_and_vectorize_matmul.mlir +++ b/iree/compiler/Conversion/LinalgToSPIRV/test/tile_and_vectorize_matmul.mlir @@ -9,7 +9,7 @@ hal.executable @matmul_static_shape_f16 attributes {sym_visibility = "private"} hal.executable.target @vulkan_spirv, filter="vulkan*" { hal.executable.entry_point @matmul_static_shape_f16 attributes { interface = @legacy_io, ordinal = 0 : i32, - signature = (!flow.dispatch.input<1x225x225x16xf32>, !flow.dispatch.input<3x3x16x32xf32>, !flow.dispatch.output<1x112x112x32xf32>) -> ()} + signature = (!flow.dispatch.tensor, !flow.dispatch.tensor, !flow.dispatch.tensor) -> ()} module attributes {spv.target_env = #spv.target_env<#spv.vce, ARM:IntegratedGPU, {}>} { func @matmul_static_shape_f16() { %cst = constant 0.000000e+00 : f16 diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/workgroup_memory_promotion.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/workgroup_memory_promotion.mlir index c96ce203934f..fe0fb852028e 100644 --- a/iree/compiler/Conversion/LinalgToSPIRV/test/workgroup_memory_promotion.mlir +++ b/iree/compiler/Conversion/LinalgToSPIRV/test/workgroup_memory_promotion.mlir @@ -10,8 +10,8 @@ hal.executable @matmul_tile attributes {sym_visibility = "private"} { hal.executable.target @vulkan, filter="dylib*" { hal.executable.entry_point @matmul_tile attributes { interface = @legacy_io, ordinal = 0 : i32, - signature = (!flow.dispatch.input<25x50xf32>, !flow.dispatch.input<50x75xf32>, - !flow.dispatch.output<25x75xf32>) -> ()} + signature = (!flow.dispatch.tensor, !flow.dispatch.tensor, + !flow.dispatch.tensor) -> ()} module attributes { spv.target_env = #spv.target_env<#spv.vce, @@ -68,8 +68,8 @@ hal.executable @conv_no_padding_tile attributes {sym_visibility = "private"} { hal.executable.target @vulkan, filter="dylib*" { hal.executable.entry_point @conv_no_padding_tile attributes { interface = @legacy_io, ordinal = 0 : i32, - signature = (!flow.dispatch.input<3x4x6x14xf32>, !flow.dispatch.input<2x16x16x6xf32>, - !flow.dispatch.output<2x13x11x14xf32>) -> ()} + signature = (!flow.dispatch.tensor, !flow.dispatch.tensor, + !flow.dispatch.tensor) -> ()} module attributes { spv.target_env = #spv.target_env<#spv.vce, diff --git a/iree/compiler/Dialect/Flow/IR/FlowBase.td b/iree/compiler/Dialect/Flow/IR/FlowBase.td index 9e1b62ae8bb0..c2a1f515329d 100644 --- a/iree/compiler/Dialect/Flow/IR/FlowBase.td +++ b/iree/compiler/Dialect/Flow/IR/FlowBase.td @@ -101,29 +101,58 @@ def FLOW_Workload : AnyTypeOf<[Index]> { }]; } -def FLOW_DispatchInput : DialectType< +def FLOW_TensorAccessCanReadPred : CPred<[{ + $_self.cast().getAccess() == + IREE::Flow::TensorAccess::ReadOnly || + $_self.cast().getAccess() == + IREE::Flow::TensorAccess::ReadWrite +}]>; + +def FLOW_TensorAccessCanWritePred : CPred<[{ + $_self.cast().getAccess() == + IREE::Flow::TensorAccess::WriteOnly || + $_self.cast().getAccess() == + IREE::Flow::TensorAccess::ReadWrite +}]>; + +def FLOW_DispatchTensor : DialectType< FLOW_Dialect, - CPred<"$_self.isa()">, - "dispatch.input"> { + CPred<"$_self.isa()">, + "dispatch.tensor"> { let description = [{ - A placeholder for a dispatch region input operand. This can be used to query - the metadata about the input (such as its shape) as well as load from the - backing tensor representation. + A placeholder for a dispatch region input/output operand. This can be used + to query the metadata about the tensor (such as its shape) as well as both + load and store from the backing tensor representation. }]; } -def FLOW_DispatchOutput : DialectType< +def FLOW_ReadableDispatchTensor : DialectType< FLOW_Dialect, - CPred<"$_self.isa()">, - "dispatch.output"> { + And<[ + CPred<"$_self.isa()">, + FLOW_TensorAccessCanReadPred, + ]>, + "dispatch.tensor"> { let description = [{ - A placeholder for a dispatch region output result. This can be used to - query the metadata about the output (such as its shape) as well as store - into the backing tensor representation. + A placeholder for a dispatch region input operand. This can be used + to query the metadata about the tensor (such as its shape) as well as load + from the backing tensor representation. }]; } -def FLOW_DispatchIO : AnyTypeOf<[FLOW_DispatchInput, FLOW_DispatchOutput]>; +def FLOW_WritableDispatchTensor : DialectType< + FLOW_Dialect, + And<[ + CPred<"$_self.isa()">, + FLOW_TensorAccessCanWritePred, + ]>, + "dispatch.tensor"> { + let description = [{ + A placeholder for a dispatch region output operand. This can be used + to query the metadata about the tensor (such as its shape) as well as store + to the backing tensor representation. + }]; +} // Use no padding and clamp the window to the valid area, possibly stopping // early prior to having covered all data. diff --git a/iree/compiler/Dialect/Flow/IR/FlowDialect.cpp b/iree/compiler/Dialect/Flow/IR/FlowDialect.cpp index 598dddd24ce6..be5f5969d4c8 100644 --- a/iree/compiler/Dialect/Flow/IR/FlowDialect.cpp +++ b/iree/compiler/Dialect/Flow/IR/FlowDialect.cpp @@ -48,7 +48,7 @@ struct FlowFolderInterface : public DialectFoldInterface { FlowDialect::FlowDialect(MLIRContext *context) : Dialect(getDialectNamespace(), context, TypeID::get()) { addInterfaces(); - addTypes(); + addTypes(); #define GET_OP_LIST addOperations< @@ -70,10 +70,8 @@ Operation *FlowDialect::materializeConstant(OpBuilder &builder, Attribute value, Type FlowDialect::parseType(DialectAsmParser &parser) const { llvm::StringRef spec = parser.getFullSymbolSpec(); - if (succeeded(parser.parseOptionalKeyword("dispatch.input"))) { - return DispatchInputType::parse(parser); - } else if (succeeded(parser.parseOptionalKeyword("dispatch.output"))) { - return DispatchOutputType::parse(parser); + if (succeeded(parser.parseOptionalKeyword("dispatch.tensor"))) { + return DispatchTensorType::parse(parser); } parser.emitError(parser.getCurrentLocation()) << "unknown Flow type: " << spec; @@ -81,10 +79,8 @@ Type FlowDialect::parseType(DialectAsmParser &parser) const { } void FlowDialect::printType(Type type, DialectAsmPrinter &p) const { - if (auto inputType = type.dyn_cast()) { + if (auto inputType = type.dyn_cast()) { IREE::Flow::printType(inputType, p); - } else if (auto outputType = type.dyn_cast()) { - IREE::Flow::printType(outputType, p); } else { llvm_unreachable("unknown Flow type"); } diff --git a/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp b/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp index a3dbd3da19fe..acb5c45ff20c 100644 --- a/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp +++ b/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp @@ -18,10 +18,12 @@ #include "iree/compiler/Dialect/Flow/IR/FlowDialect.h" #include "iree/compiler/Dialect/Flow/IR/FlowOpUtils.h" #include "iree/compiler/Dialect/Flow/IR/FlowOps.h" +#include "iree/compiler/Dialect/IREE/IR/IREETypes.h" #include "iree/compiler/Dialect/Shape/IR/ShapeOps.h" #include "llvm/ADT/MapVector.h" #include "llvm/ADT/Optional.h" #include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/StringExtras.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" @@ -89,9 +91,112 @@ static SmallVector refreshDimsOnTypeChange( // Streams //===----------------------------------------------------------------------===// +namespace { + +// Returns true if the given |value| is used again after |updateOp| consumes it. +static bool hasUsersInStreamAfterUpdate(Value value, Operation *updateOp) { + for (auto user : value.getUsers()) { + if (user == updateOp) continue; + if (user->isBeforeInBlock(updateOp)) continue; + return true; + } + return false; +} + +/// Inserts clones into the stream as required by tied results. +/// This is required to preserve the immutable tensor semantics required by the +/// SSA use-def chain. +/// +/// Example: +/// %0 = flow.dispatch +/// // %0 will be updated in-place and renamed %1: +/// %1 = flow.dispatch %0 -> %0 +/// // The original value of %0 (aka %1) is required but is not valid! +/// %2 = flow.dispatch %0 +/// -> +/// %0 = flow.dispatch +/// // Capture the value of %0 before it is modified: +/// %clone = flow.tensor.clone %0 +/// // Update %0 in-place and rename to %1, safe as %0 now has one use: +/// %1 = flow.dispatch %0 -> %0 +/// // Use the cloned %0 value: +/// %2 = flow.dispatch %clone +struct InsertImmutabilityPreservingStreamClones + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(ExStreamFragmentOp op, + PatternRewriter &rewriter) const override { + bool didClone = + insertTiedClones(cast(op.getOperation()), rewriter); + for (auto &block : op.getClosureBodyRegion()) { + for (auto &innerOp : block) { + if (auto tiedOp = dyn_cast(innerOp)) { + didClone |= insertTiedClones(tiedOp, rewriter); + } + } + } + return didClone ? success() : failure(); + } + + bool insertTiedClones(TiedOpInterface tiedOp, + PatternRewriter &rewriter) const { + bool didClone = false; + for (unsigned resultIndex = 0; resultIndex < tiedOp->getNumResults(); + ++resultIndex) { + auto tiedOperandIndex = tiedOp.getTiedResultOperandIndex(resultIndex); + if (!tiedOperandIndex.hasValue()) continue; + auto tiedOperand = tiedOp->getOperand(tiedOperandIndex.getValue()); + if (hasUsersInStreamAfterUpdate(tiedOperand, tiedOp)) { + rewriter.setInsertionPointAfterValue(tiedOperand); + auto clonedOperand = rewriter.createOrFold( + tiedOperand.getLoc(), tiedOperand); + SmallPtrSet excludedOps; + excludedOps.insert(tiedOp.getOperation()); + excludedOps.insert(clonedOperand.getDefiningOp()); + tiedOperand.replaceAllUsesExcept(clonedOperand, excludedOps); + didClone = true; + } + } + return didClone; + } +}; + +/// Ties the results of streams to their operands when the stream operations are +/// tied throughout the entire body. +struct TieStreamResults : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(ExStreamFragmentOp op, + PatternRewriter &rewriter) const override { + assert(op.getRegion().getBlocks().size() == 1 && + "only one stream block supported"); + bool didModify = false; + op.walk([&](IREE::Flow::ReturnOp returnOp) { + for (auto result : llvm::enumerate(returnOp.getOperands())) { + if (op.getTiedResultOperandIndex(result.index()).hasValue()) { + continue; // Already tied. + } + auto baseValue = + IREE::TiedOpInterface::findTiedBaseValue(result.value()); + if (auto blockArg = baseValue.dyn_cast()) { + unsigned operandIndex = blockArg.getArgNumber(); + op.setTiedResultOperandIndex(result.index(), operandIndex); + didModify = true; + } + } + }); + return didModify ? success() : failure(); + } +}; + +} // namespace + void ExStreamFragmentOp::getCanonicalizationPatterns( OwningRewritePatternList &results, MLIRContext *context) { results.insert>(context); + results.insert(context); + results.insert(context); } //===----------------------------------------------------------------------===// @@ -250,21 +355,21 @@ void DispatchWorkgroupsOp::getCanonicalizationPatterns( } //===----------------------------------------------------------------------===// -// flow.dispatch.input.load +// flow.dispatch.tensor.load //===----------------------------------------------------------------------===// namespace { // Some linalg patterns, due to being upstream, tend to introduce `dim` ops. // These generally fold with upstream patterns when tensors are involved, but -// when DispatchInputLoadOp's are involved (with dispatch tensor types), +// when DispatchTensorLoadOp's are involved (with dispatch tensor types), // then this starts to break down, which causes the `dim` ops to survive // arbitrarily late into the pipeline. Often, they keep alive -// DispatchInputLoadOp's that would otherwise be dead! +// DispatchTensorLoadOp's that would otherwise be dead! // // To fix this, we convert the `std.dim` ops to `flow.dispatch.shape` ops. // ``` -// dim(flow.dispatch.input.load(%x), %const) +// dim(flow.dispatch.tensor.load(%x), %const) // -> // shapex.ranked_dim(flow.dispatch.shape(%x), %const) // `` @@ -274,15 +379,14 @@ struct ConvertDimOfDispatchInputLoadToDispatchShape LogicalResult matchAndRewrite(DimOp op, PatternRewriter &rewriter) const override { - auto dispatchInputLoad = - op.memrefOrTensor().getDefiningOp(); - if (!dispatchInputLoad) return failure(); + auto loadOp = op.memrefOrTensor().getDefiningOp(); + if (!loadOp) return failure(); Optional constantIndex = op.getConstantIndex(); if (!constantIndex.hasValue()) return failure(); - auto rankedShape = rewriter.create( - op.getLoc(), dispatchInputLoad.source()); + auto rankedShape = + rewriter.create(op.getLoc(), loadOp.source()); rewriter.replaceOpWithNewOp(op, rankedShape, *constantIndex); return success(); @@ -291,7 +395,7 @@ struct ConvertDimOfDispatchInputLoadToDispatchShape } // namespace -void DispatchInputLoadOp::getCanonicalizationPatterns( +void DispatchTensorLoadOp::getCanonicalizationPatterns( OwningRewritePatternList &results, MLIRContext *context) { results.insert(context); } @@ -456,10 +560,21 @@ OpFoldResult TensorSplatOp::fold(ArrayRef operands) { OpFoldResult TensorCloneOp::fold(ArrayRef operands) { if (operands[0]) { + // Constants always fold. return operands[0]; } - // TODO(benvanik): fold if clone device placements differ. - return operand(); + + // TODO(benvanik): elide clones when safe to do so. Right now clone is + // load-bearing to work around our lack of cross-stream scheduling. Clones are + // inserted to avoid mutating function arguments and any logic we perform here + // (without *also* checking all the conditions that may insert a clone) will + // just fight. + // + // Once the clones are not load-bearing we can remove them in all the normal + // cases (one user, no intervening uses between clone and consumers of + // operands, etc). + + return {}; } // Slices tensor from start to (start + length) exclusively at dim. @@ -601,7 +716,8 @@ struct FoldTensorUpdateOpWithCasts : public OpRewritePattern { updateOp.start_indices(), update, refreshDimsOnTypeChange(updateOp, updateOp.update().getType(), update.getType(), updateOp.update_dims(), - rewriter)); + rewriter), + updateOp.tied_operandsAttr()); rewriter.replaceOpWithNewOp( updateOp, updateOp.getResult().getType(), newOp.getResult()); return success(); diff --git a/iree/compiler/Dialect/Flow/IR/FlowOps.cpp b/iree/compiler/Dialect/Flow/IR/FlowOps.cpp index 4fff767d3b8a..34aeeba4512b 100644 --- a/iree/compiler/Dialect/Flow/IR/FlowOps.cpp +++ b/iree/compiler/Dialect/Flow/IR/FlowOps.cpp @@ -17,6 +17,7 @@ #include "iree/compiler/Dialect/Flow/IR/FlowOpUtils.h" #include "iree/compiler/Dialect/IREE/IR/IREETypes.h" #include "iree/compiler/Dialect/Shape/IR/Builders.h" +#include "llvm/ADT/BitVector.h" #include "llvm/ADT/StringExtras.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/IR/Attributes.h" @@ -67,12 +68,58 @@ static LogicalResult verifyOpDynamicDims(Operation *op, ValueRange values, return success(); } +//===----------------------------------------------------------------------===// +// custom +//===----------------------------------------------------------------------===// +// type{%dim0, %dim1} +// %arg0 + +static ParseResult parseTiedResult( + OpAsmParser &parser, Type &resultType, + SmallVectorImpl &resultDims, + ArrayAttr &tiedOperands) { + if (failed(parser.parseType(resultType))) return failure(); + if (auto shapedType = resultType.dyn_cast()) { + if (!shapedType.hasStaticShape()) { + SmallVector dynamicDims; + if (failed(parser.parseLBrace()) || + failed(parser.parseOperandList(dynamicDims, + shapedType.getNumDynamicDims(), + OpAsmParser::Delimiter::None)) || + failed(parser.parseRBrace())) { + return failure(); + } + resultDims.append(dynamicDims); + } + } + tiedOperands = parser.getBuilder().getIndexArrayAttr({0}); + return success(); +} + +static void printTiedResult(OpAsmPrinter &p, Operation *op, Type resultType, + ValueRange resultDims, ArrayAttr tiedOperands) { + p.printType(resultType); + if (auto shapedType = resultType.dyn_cast()) { + if (!shapedType.hasStaticShape()) { + if (resultDims.empty()) { + p << "{<>}"; + return; + } + p << "{"; + llvm::interleaveComma( + resultDims.take_front(shapedType.getNumDynamicDims()), p, + [&](Value value) { p.printOperand(value); }); + p << "}"; + } + } +} + //===----------------------------------------------------------------------===// // custom //===----------------------------------------------------------------------===// -// (type, type{%dim0, %dim1}, type) -> type{%dim2} +// (type, type{%dim0, %dim1}, type) -> (type{%dim2}, %operand4) -static ParseResult parseShapedTypeList( +static ParseResult parseShapedOperandList( OpAsmParser &parser, SmallVectorImpl &types, SmallVectorImpl &dims) { do { @@ -80,12 +127,15 @@ static ParseResult parseShapedTypeList( if (failed(parser.parseType(type))) return failure(); if (auto shapedType = type.dyn_cast()) { if (!shapedType.hasStaticShape()) { + SmallVector dynamicDims; if (failed(parser.parseLBrace()) || - failed(parser.parseOperandList(dims, shapedType.getNumDynamicDims(), + failed(parser.parseOperandList(dynamicDims, + shapedType.getNumDynamicDims(), OpAsmParser::Delimiter::None)) || failed(parser.parseRBrace())) { return failure(); } + dims.append(dynamicDims); } } types.push_back(type); @@ -93,60 +143,174 @@ static ParseResult parseShapedTypeList( return success(); } -static void printShapedTypeList(OpAsmPrinter &p, Operation *op, TypeRange types, - OperandRange dims) { - llvm::interleaveComma(types, p, [&](Type type) { - p.printType(type); - if (auto shapedType = type.dyn_cast()) { - if (!shapedType.hasStaticShape()) { - if (dims.empty()) { - p << "{<>}"; - return; +// Ties the |tiedResult| parsed operand back to a previously parsed operand. +// The type and any dynamic dimensions of the operand will be used for the +// result values and the operand index will be appended to |tiedOperandIndices|. +static ParseResult tieOperand( + OpAsmParser::OperandType tiedResult, OpAsmParser &parser, + ArrayRef operands, TypeRange operandTypes, + ArrayRef operandDims, + SmallVectorImpl &resultTypes, + SmallVectorImpl &resultDims, + SmallVectorImpl &tiedOperandIndices) { + int64_t operandIndex = TiedOpInterface::kUntiedIndex; + for (int64_t i = 0; i < operands.size(); ++i) { + if (operands[i].name == tiedResult.name) { + operandIndex = i; + break; + } + } + if (operandIndex == TiedOpInterface::kUntiedIndex) { + return parser.emitError(tiedResult.location, + "tied operand not found for result reference ") + << tiedResult.name; + } + + auto resultType = operandTypes[operandIndex]; + resultTypes.push_back(resultType); + tiedOperandIndices.push_back(operandIndex); + + auto shapedType = resultType.dyn_cast(); + if (shapedType) { + unsigned dimsIndex = 0; + for (unsigned i = 0; i < operandIndex; ++i) { + if (auto shapedType = operandTypes[i].dyn_cast()) { + dimsIndex += shapedType.getNumDynamicDims(); + } + } + resultDims.append(llvm::to_vector<4>( + operandDims.slice(dimsIndex, shapedType.getNumDynamicDims()))); + } + + return success(); +} + +static ParseResult parseShapedResultList( + OpAsmParser &parser, ArrayRef operands, + TypeRange operandTypes, ArrayRef operandDims, + SmallVectorImpl &resultTypes, + SmallVectorImpl &resultDims, + ArrayAttr &tiedOperands) { + SmallVector tiedOperandIndices; + do { + OpAsmParser::OperandType tiedResult; + auto res = parser.parseOptionalOperand(tiedResult); + if (res.hasValue() && succeeded(res.getValue())) { + if (failed(tieOperand(tiedResult, parser, operands, operandTypes, + operandDims, resultTypes, resultDims, + tiedOperandIndices))) { + return failure(); + } + } else { + Type type; + if (failed(parser.parseType(type))) return failure(); + if (auto shapedType = type.dyn_cast()) { + if (!shapedType.hasStaticShape()) { + SmallVector dynamicDims; + if (failed(parser.parseLBrace()) || + failed(parser.parseOperandList(dynamicDims, + shapedType.getNumDynamicDims(), + OpAsmParser::Delimiter::None)) || + failed(parser.parseRBrace())) { + return failure(); + } + resultDims.append(dynamicDims); } - p << "{"; - llvm::interleaveComma(dims.take_front(shapedType.getNumDynamicDims()), - p, [&](Value value) { p.printOperand(value); }); - p << "}"; - dims = dims.drop_front(shapedType.getNumDynamicDims()); } + resultTypes.push_back(type); + tiedOperandIndices.push_back(TiedOpInterface::kUntiedIndex); } - }); + } while (succeeded(parser.parseOptionalComma())); + if (!tiedOperandIndices.empty()) { + tiedOperands = parser.getBuilder().getIndexArrayAttr(tiedOperandIndices); + } + return success(); } static ParseResult parseShapedFunctionType( - OpAsmParser &parser, SmallVectorImpl &operandTypes, + OpAsmParser &parser, ArrayRef operands, + SmallVectorImpl &operandTypes, SmallVectorImpl &operandDims, SmallVectorImpl &resultTypes, - SmallVectorImpl &resultDims) { + SmallVectorImpl &resultDims, + ArrayAttr &tiedOperands) { if (failed(parser.parseLParen())) return failure(); if (failed(parser.parseOptionalRParen())) { - if (failed(parseShapedTypeList(parser, operandTypes, operandDims)) || + if (failed(parseShapedOperandList(parser, operandTypes, operandDims)) || failed(parser.parseRParen())) { return failure(); } } if (failed(parser.parseArrow())) return failure(); if (succeeded(parser.parseOptionalLParen())) { - if (failed(parseShapedTypeList(parser, resultTypes, resultDims)) || + if (failed(parseShapedResultList(parser, operands, operandTypes, + operandDims, resultTypes, resultDims, + tiedOperands)) || failed(parser.parseRParen())) { return failure(); } - } else if (failed(parseShapedTypeList(parser, resultTypes, resultDims))) { - return failure(); + } else { + if (failed(parseShapedResultList(parser, operands, operandTypes, + operandDims, resultTypes, resultDims, + tiedOperands))) { + return failure(); + } } return success(); } static void printShapedFunctionType(OpAsmPrinter &p, Operation *op, - TypeRange operandTypes, + ValueRange operands, TypeRange operandTypes, OperandRange operandDims, TypeRange resultTypes, - OperandRange resultDims) { + OperandRange resultDims, + ArrayAttr tiedOperands) { p << "("; - printShapedTypeList(p, op, operandTypes, operandDims); - p << ") -> ("; - printShapedTypeList(p, op, resultTypes, resultDims); - p << ")"; + llvm::interleaveComma(operandTypes, p, [&](Type type) { + p.printType(type); + if (auto shapedType = type.dyn_cast()) { + if (!shapedType.hasStaticShape()) { + if (operandDims.empty()) { + p << "{<>}"; + return; + } + p << "{"; + llvm::interleaveComma( + operandDims.take_front(shapedType.getNumDynamicDims()), p, + [&](Value value) { p.printOperand(value); }); + p << "}"; + operandDims = operandDims.drop_front(shapedType.getNumDynamicDims()); + } + } + }); + p << ") -> "; + if (resultTypes.size() != 1) p << "("; + auto tiedOp = cast(op); + for (unsigned i = 0; i < resultTypes.size(); ++i) { + auto tiedOperand = tiedOp.getTiedResultOperandIndex(i); + if (tiedOperand.hasValue()) { + p.printOperand(op->getOperand(tiedOperand.getValue())); + } else { + auto type = resultTypes[i]; + p.printType(type); + if (auto shapedType = type.dyn_cast()) { + if (!shapedType.hasStaticShape()) { + if (resultDims.empty()) { + p << "{<>}"; + return; + } + p << "{"; + llvm::interleaveComma( + resultDims.take_front(shapedType.getNumDynamicDims()), p, + [&](Value value) { p.printOperand(value); }); + p << "}"; + resultDims = resultDims.drop_front(shapedType.getNumDynamicDims()); + } + } + } + if (i < resultTypes.size() - 1) p << ", "; + } + if (resultTypes.size() != 1) p << ")"; } //===----------------------------------------------------------------------===// @@ -674,6 +838,7 @@ void DispatchWorkgroupsOp::build(OpBuilder &builder, OperationState &state, ValueRange workgroupCount, TypeRange resultTypes, ValueRange resultDims, ValueRange operands, ValueRange operandDims, + ArrayRef tiedOperands, ArrayRef attributes) { state.addTypes(resultTypes); state.addOperands(workgroupCount); @@ -681,6 +846,9 @@ void DispatchWorkgroupsOp::build(OpBuilder &builder, OperationState &state, state.addOperands(operandDims); state.addOperands(resultDims); state.addAttributes(attributes); + state.attributes.erase(TiedOpInterface::getStorageAttrName()); + state.addAttribute(TiedOpInterface::getStorageAttrName(), + builder.getIndexArrayAttr(tiedOperands)); state.attributes.erase("operand_segment_sizes"); state.addAttribute("operand_segment_sizes", builder.getI32VectorAttr({ @@ -696,17 +864,36 @@ void DispatchWorkgroupsOp::build(OpBuilder &builder, OperationState &state, OpBuilder::InsertionGuard g(builder); builder.createBlock(body); // createBlock implicitly moves IP, RAII away... } - for (auto operand : operands) { - Type type = operand.getType(); + + llvm::BitVector operandAliases(llvm::size(operands), false); + llvm::BitVector resultAliases(llvm::size(resultTypes), false); + for (unsigned resultIndex = 0; resultIndex < tiedOperands.size(); + ++resultIndex) { + int64_t tiedOperandIndex = tiedOperands[resultIndex]; + if (tiedOperandIndex != TiedOpInterface::kUntiedIndex) { + operandAliases[tiedOperandIndex] = true; + resultAliases[resultIndex] = true; + } + } + + for (auto operand : llvm::enumerate(operands)) { + Type type = operand.value().getType(); if (auto tensorType = type.dyn_cast()) { - type = DispatchInputType::get(tensorType); + type = DispatchTensorType::get(operandAliases[operand.index()] + ? TensorAccess::ReadWrite + : TensorAccess::ReadOnly, + tensorType); } body->addArgument(type); } - for (auto resultType : resultTypes) { - Type type = resultType; + for (auto resultType : llvm::enumerate(resultTypes)) { + if (resultAliases[resultType.index()]) { + // Already handled by an aliased operand. + continue; + } + Type type = resultType.value(); if (auto tensorType = type.dyn_cast()) { - type = DispatchOutputType::get(tensorType); + type = DispatchTensorType::get(TensorAccess::WriteOnly, tensorType); } body->addArgument(type); } @@ -717,8 +904,6 @@ static ParseResult parseDispatchWorkgroupBody(OpAsmParser &parser, TypeRange operandTypes, TypeRange resultTypes, Region &body) { - auto loc = parser.getCurrentLocation(); - SmallVector regionArgs; SmallVector regionArgTypes; if (failed(parser.parseLParen())) { @@ -738,12 +923,6 @@ static ParseResult parseDispatchWorkgroupBody(OpAsmParser &parser, return failure(); } } - - if (regionArgs.size() != operandTypes.size() + resultTypes.size()) { - return parser.emitError(loc, - "region operand list required required to match " - "count of dispatch op operands + results"); - } return parser.parseRegion(body, regionArgs, regionArgTypes, /*enableNameShadowing=*/true); } @@ -808,11 +987,15 @@ DispatchWorkgroupsOp::cloneReplacementExcludingOperandsAndResults( excludeClosureOperandsAndResults(newOperandsValues, newOperandDims, excludedOperandIndices, newResultTypes, newResultDims, excludedResultIndices); - auto newOp = - OpBuilder(getContext()) - .create( - getLoc(), workgroup_count(), newResultTypes, newResultDims, - newOperandsValues, newOperandDims, getOperation()->getAttrs()); + SmallVector newTiedOperandIndices = + llvm::to_vector<4>(getTiedResultOperandIndices()); + excludeTiedOperandAndResultIndices( + excludedOperandIndices, excludedResultIndices, newTiedOperandIndices); + auto newOp = OpBuilder(getContext()) + .create( + getLoc(), workgroup_count(), newResultTypes, + newResultDims, newOperandsValues, newOperandDims, + newTiedOperandIndices, getOperation()->getAttrs()); auto &newBody = newOp.getClosureBodyRegion(); newBody.takeBody(getClosureBodyRegion()); newBody.front().eraseArguments(excludedOperandIndices); @@ -823,6 +1006,11 @@ DispatchWorkgroupsOp::cloneReplacementExcludingOperandsAndResults( return newOp; } +std::pair +DispatchWorkgroupsOp::getTiedOperandsIndexAndLength() { + return getODSOperandIndexAndLength(1); +} + //===----------------------------------------------------------------------===// // flow.dispatch.workgroup.* //===----------------------------------------------------------------------===// @@ -991,6 +1179,7 @@ void DispatchOp::build(OpBuilder &builder, OperationState &state, DispatchEntryOp entryPoint, ValueRange workgroupCount, TypeRange resultTypes, ValueRange resultDims, ValueRange operands, ValueRange operandDims, + ArrayRef tiedOperands, ArrayRef attributes) { StringRef executableOpSymName = entryPoint->getParentOp() @@ -1007,6 +1196,9 @@ void DispatchOp::build(OpBuilder &builder, OperationState &state, state.addOperands(operandDims); state.addOperands(resultDims); state.addAttributes(attributes); + state.attributes.erase(TiedOpInterface::getStorageAttrName()); + state.addAttribute(TiedOpInterface::getStorageAttrName(), + builder.getIndexArrayAttr(tiedOperands)); state.attributes.erase("operand_segment_sizes"); state.addAttribute("operand_segment_sizes", builder.getI32VectorAttr({ @@ -1045,6 +1237,10 @@ Value DispatchOp::buildResultRankedShape(unsigned idx, OpBuilder &builder) { result_dims(), builder); } +std::pair DispatchOp::getTiedOperandsIndexAndLength() { + return getODSOperandIndexAndLength(1); +} + //===----------------------------------------------------------------------===// // flow.tensor.* //===----------------------------------------------------------------------===// @@ -1121,7 +1317,7 @@ void TensorUpdateOp::build(OpBuilder &builder, OperationState &state, auto updateDims = Shape::buildOrFindDynamicDimsForValue(state.location, update, builder); build(builder, state, target.getType(), target, targetDims, startIndices, - update, updateDims); + update, updateDims, builder.getIndexArrayAttr({0})); } static LogicalResult verifyTensorUpdateOp(TensorUpdateOp op) { @@ -1151,6 +1347,19 @@ Value TensorUpdateOp::buildResultRankedShape(unsigned idx, OpBuilder &builder) { builder); } +Value TensorUpdateOp::getTiedResult(unsigned resultIndex) { + return IREE::TiedOpInterface::findTiedBaseValue(target()); +} + +::llvm::Optional TensorUpdateOp::getTiedResultOperandIndex( + unsigned resultIndex) { + return 0; // target +} + +SmallVector TensorUpdateOp::getTiedResultOperandIndices() { + return {0}; // target +} + //===----------------------------------------------------------------------===// // flow.ex.stream.fragment //===----------------------------------------------------------------------===// @@ -1158,12 +1367,16 @@ Value TensorUpdateOp::buildResultRankedShape(unsigned idx, OpBuilder &builder) { void ExStreamFragmentOp::build(OpBuilder &builder, OperationState &state, TypeRange resultTypes, ValueRange resultDims, ValueRange operands, ValueRange operandDims, + ArrayRef tiedOperands, ArrayRef attributes) { state.addTypes(resultTypes); state.addOperands(operands); state.addOperands(operandDims); state.addOperands(resultDims); state.addAttributes(attributes); + state.attributes.erase(TiedOpInterface::getStorageAttrName()); + state.addAttribute(TiedOpInterface::getStorageAttrName(), + builder.getIndexArrayAttr(tiedOperands)); state.attributes.erase("operand_segment_sizes"); state.addAttribute("operand_segment_sizes", builder.getI32VectorAttr({ @@ -1185,6 +1398,7 @@ static LogicalResult verifyExStreamFragmentOp(ExStreamFragmentOp op) { static ParseResult parseStreamFragmentBody(OpAsmParser &parser, TypeRange operandTypes, TypeRange resultTypes, + ArrayAttr tiedOperands, Region &body) { auto loc = parser.getCurrentLocation(); @@ -1224,16 +1438,21 @@ static ParseResult parseStreamFragmentBody(OpAsmParser &parser, static void printStreamFragmentBody(OpAsmPrinter &p, Operation *op, TypeRange operandTypes, - TypeRange resultTypes, Region &body) { + TypeRange resultTypes, + ArrayAttr tiedOperands, Region &body) { p << "("; llvm::interleaveComma(body.getArguments(), p, [&](BlockArgument arg) { p << arg; p << ": "; p << arg.getType(); }); - p << ") -> ("; - llvm::interleaveComma(resultTypes, p, [&](Type type) { p.printType(type); }); - p << ")"; + p << ") -> "; + if (resultTypes.size() != 1) p << "("; + for (unsigned i = 0; i < resultTypes.size(); ++i) { + p.printType(resultTypes[i]); + if (i < resultTypes.size() - 1) p << ", "; + } + if (resultTypes.size() != 1) p << ")"; p.printRegion(body, /*printEntryBlockArgs=*/false, /*printBlockTerminators=*/true); } @@ -1278,11 +1497,15 @@ ExStreamFragmentOp::cloneReplacementExcludingOperandsAndResults( excludeClosureOperandsAndResults(newOperandsValues, newOperandDims, excludedOperandIndices, newResultTypes, newResultDims, excludedResultIndices); - auto newOp = - OpBuilder(getContext()) - .create(getLoc(), newResultTypes, newResultDims, - newOperandsValues, newOperandDims, - getOperation()->getAttrs()); + SmallVector newTiedOperandIndices = + llvm::to_vector<4>(getTiedResultOperandIndices()); + excludeTiedOperandAndResultIndices( + excludedOperandIndices, excludedResultIndices, newTiedOperandIndices); + auto newOp = OpBuilder(getContext()) + .create( + getLoc(), newResultTypes, newResultDims, + newOperandsValues, newOperandDims, newTiedOperandIndices, + getOperation()->getAttrs()); auto &newBody = newOp.getClosureBodyRegion(); newBody.takeBody(getClosureBodyRegion()); eraseRegionResults(newBody, excludedResultIndices); diff --git a/iree/compiler/Dialect/Flow/IR/FlowOps.td b/iree/compiler/Dialect/Flow/IR/FlowOps.td index a90c276d118b..a0f43bb9471b 100644 --- a/iree/compiler/Dialect/Flow/IR/FlowOps.td +++ b/iree/compiler/Dialect/Flow/IR/FlowOps.td @@ -17,6 +17,7 @@ include "iree/compiler/Dialect/Flow/IR/FlowBase.td" include "iree/compiler/Dialect/Flow/IR/FlowInterfaces.td" +include "iree/compiler/Dialect/IREE/IR/IREEInterfaces.td" include "iree/compiler/Dialect/Shape/IR/ShapeInterfaces.td" include "mlir/IR/OpAsmInterface.td" include "mlir/IR/SymbolInterfaces.td" @@ -241,6 +242,9 @@ def FLOW_DispatchWorkgroupsOp : FLOW_PureOp<"dispatch.workgroups", [ AttrSizedOperandSegments, SingleBlockImplicitTerminator<"IREE::Flow::ReturnOp">, DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, ]> { let summary = [{a dispatch of workgroups across an n-dimension grid}]; @@ -283,7 +287,8 @@ def FLOW_DispatchWorkgroupsOp : FLOW_PureOp<"dispatch.workgroups", [ Variadic:$workgroup_count, Variadic:$operands, FLOW_ShapeDynamicDims:$operand_dims, - FLOW_ShapeDynamicDims:$result_dims + FLOW_ShapeDynamicDims:$result_dims, + OptionalAttr:$tied_operands ); let results = (outs Variadic:$results @@ -294,8 +299,10 @@ def FLOW_DispatchWorkgroupsOp : FLOW_PureOp<"dispatch.workgroups", [ let assemblyFormat = [{ `[` $workgroup_count `]` `` `(` $operands `)` `:` - custom(type($operands), $operand_dims, - type($results), $result_dims) + custom(ref($operands), + type($operands), $operand_dims, + type($results), $result_dims, + $tied_operands) attr-dict-with-keyword `=` `\n` ` ` ` ` ` ` custom(ref(type($operands)), @@ -309,6 +316,7 @@ def FLOW_DispatchWorkgroupsOp : FLOW_PureOp<"dispatch.workgroups", [ "ValueRange":$workgroupCount, "TypeRange":$resultTypes, "ValueRange":$resultDims, "ValueRange":$operands, "ValueRange":$operandDims, + "ArrayRef":$tiedOperands, CArg<"ArrayRef", "{}">:$attributes)>, ]; @@ -467,7 +475,7 @@ def FLOW_DispatchShapeOp : FLOW_PureOp<"dispatch.shape", [ }]; let arguments = (ins - FLOW_DispatchIO:$source + FLOW_DispatchTensor:$source ); let results = (outs Shape_RankedShape:$result @@ -487,11 +495,11 @@ def FLOW_DispatchTieShapeOp : FLOW_PureOp<"dispatch.tie_shape"> { }]; let arguments = (ins - FLOW_DispatchIO:$operand, + FLOW_DispatchTensor:$operand, Shape_RankedShape:$shape ); let results = (outs - FLOW_DispatchIO:$result + FLOW_DispatchTensor:$result ); // TODO(benvanik): figure out a way to make this look like shapex.tie_shape. @@ -503,7 +511,7 @@ def FLOW_DispatchTieShapeOp : FLOW_PureOp<"dispatch.tie_shape"> { let hasCanonicalizer = 1; } -def FLOW_DispatchInputLoadOp : FLOW_PureOp<"dispatch.input.load", [ +def FLOW_DispatchTensorLoadOp : FLOW_PureOp<"dispatch.tensor.load", [ SameVariadicOperandSize, ]> { let summary = [{loads a tensor from a dispatch input placeholder}]; @@ -514,7 +522,7 @@ def FLOW_DispatchInputLoadOp : FLOW_PureOp<"dispatch.input.load", [ }]; let arguments = (ins - FLOW_DispatchInput:$source, + FLOW_ReadableDispatchTensor:$source, Variadic:$offsets, Variadic:$sizes, Variadic:$strides @@ -534,7 +542,7 @@ def FLOW_DispatchInputLoadOp : FLOW_PureOp<"dispatch.input.load", [ let hasCanonicalizer = 1; } -def FLOW_DispatchOutputStoreOp : FLOW_Op<"dispatch.output.store", [ +def FLOW_DispatchTensorStoreOp : FLOW_Op<"dispatch.tensor.store", [ SameVariadicOperandSize, ]> { let summary = [{stores a tensor into a dispatch output placeholder}]; @@ -546,7 +554,7 @@ def FLOW_DispatchOutputStoreOp : FLOW_Op<"dispatch.output.store", [ let arguments = (ins AnyRankedTensor:$value, - FLOW_DispatchOutput:$target, + FLOW_WritableDispatchTensor:$target, Variadic:$offsets, Variadic:$sizes, Variadic:$strides @@ -655,6 +663,9 @@ def FLOW_DispatchEntryOp : FLOW_Op<"dispatch.entry", [ def FLOW_DispatchOp : FLOW_PureOp<"dispatch", [ AttrSizedOperandSegments, FLOW_StreamableOp, + DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, ]> { let summary = [{a dispatch of workgroups across an n-dimension grid}]; @@ -669,7 +680,8 @@ def FLOW_DispatchOp : FLOW_PureOp<"dispatch", [ SymbolRefAttr:$entry_point, Variadic:$operands, FLOW_ShapeDynamicDims:$operand_dims, - FLOW_ShapeDynamicDims:$result_dims + FLOW_ShapeDynamicDims:$result_dims, + OptionalAttr:$tied_operands ); let results = (outs Variadic:$results @@ -681,6 +693,7 @@ def FLOW_DispatchOp : FLOW_PureOp<"dispatch", [ "DispatchEntryOp":$entryPoint, "ValueRange":$workgroupCount, "TypeRange":$resultTypes, "ValueRange":$resultDims, "ValueRange":$operands, "ValueRange":$operandDims, + "ArrayRef":$tiedOperands, CArg<"ArrayRef", "{}">:$attributes)>, ]; @@ -695,8 +708,10 @@ def FLOW_DispatchOp : FLOW_PureOp<"dispatch", [ let assemblyFormat = [{ $entry_point `[` $workgroup_count `]` `` `(` $operands `)` attr-dict `:` - custom(type($operands), $operand_dims, - type($results), $result_dims) + custom(ref($operands), + type($operands), $operand_dims, + type($results), $result_dims, + $tied_operands) }]; let verifier = [{ return verifyDispatchOp(*this); }]; @@ -963,6 +978,11 @@ def FLOW_TensorUpdateOp : FLOW_PureOp<"tensor.update", [ AllTypesMatch<["target", "result"]>, AllElementTypesMatch<["update", "target", "result"]>, AttrSizedOperandSegments, + DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, ]> { let summary = [{updates a tensor with the contents of another tensor}]; @@ -976,7 +996,8 @@ def FLOW_TensorUpdateOp : FLOW_PureOp<"tensor.update", [ FLOW_ShapeDynamicDims:$target_dims, Variadic:$start_indices, FLOW_Tensor:$update, - FLOW_ShapeDynamicDims:$update_dims + FLOW_ShapeDynamicDims:$update_dims, + OptionalAttr:$tied_operands ); let results = (outs FLOW_Tensor:$result @@ -985,7 +1006,7 @@ def FLOW_TensorUpdateOp : FLOW_PureOp<"tensor.update", [ let assemblyFormat = [{ $update `,` $target `[` $start_indices `]` `:` type($update) (`{` $update_dims^ `}`)? `->` - type($result) (`{` $target_dims^ `}`)? + custom(type($result), $target_dims, $tied_operands) attr-dict-with-keyword }]; @@ -1033,18 +1054,21 @@ def FLOW_ExStreamFragmentOp : FLOW_PureOp<"ex.stream.fragment", [ IsolatedFromAbove, AttrSizedOperandSegments, DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, ]> { let summary = [{experimental op for defining formed stream regions}]; let description = [{ Represents a region where all of the dispatches are meant to target the - same execution stream. This will be replaced with a segmented verison. + same execution stream. This will be replaced with a segmented version in the + future that stitches the stream segments together. }]; let arguments = (ins Variadic:$operands, FLOW_ShapeDynamicDims:$operand_dims, - FLOW_ShapeDynamicDims:$result_dims + FLOW_ShapeDynamicDims:$result_dims, + OptionalAttr:$tied_operands ); let results = (outs Variadic:$results @@ -1054,12 +1078,15 @@ def FLOW_ExStreamFragmentOp : FLOW_PureOp<"ex.stream.fragment", [ let assemblyFormat = [{ `(` $operands `)` `:` - custom(type($operands), $operand_dims, - type($results), $result_dims) + custom(ref($operands), + type($operands), $operand_dims, + type($results), $result_dims, + $tied_operands) attr-dict-with-keyword `=` `\n` ` ` ` ` ` ` custom(ref(type($operands)), ref(type($results)), + ref($tied_operands), $body) }]; @@ -1068,6 +1095,7 @@ def FLOW_ExStreamFragmentOp : FLOW_PureOp<"ex.stream.fragment", [ OpBuilder<(ins "TypeRange":$resultTypes, "ValueRange":$resultDims, "ValueRange":$operands, "ValueRange":$operandDims, + "ArrayRef":$tiedOperands, CArg<"ArrayRef", "{}">:$attributes)>, ]; diff --git a/iree/compiler/Dialect/Flow/IR/FlowTypes.cpp b/iree/compiler/Dialect/Flow/IR/FlowTypes.cpp index b9639178c50e..c916a84b64d1 100644 --- a/iree/compiler/Dialect/Flow/IR/FlowTypes.cpp +++ b/iree/compiler/Dialect/Flow/IR/FlowTypes.cpp @@ -25,6 +25,25 @@ namespace Flow { // Object types //===----------------------------------------------------------------------===// +// static +DispatchTensorType DispatchTensorType::get(TensorAccess access, + ArrayRef shape, + Type elementType) { + return Base::get(elementType.getContext(), static_cast(access), + shape, elementType); +} + +// static +DispatchTensorType DispatchTensorType::get(TensorAccess access, + TensorType tensorType) { + return DispatchTensorType::get(access, tensorType.getShape(), + tensorType.getElementType()); +} + +TensorAccess DispatchTensorType::getAccess() const { + return static_cast(static_cast(impl)->access); +} + Type DispatchTensorType::getElementType() const { return static_cast(impl)->elementType; } @@ -79,8 +98,8 @@ bool DispatchTensorType::hasStaticShape(ArrayRef shape) const { } LogicalResult DispatchTensorType::verify( - function_ref emitError, ArrayRef shape, - Type elementType) { + function_ref emitError, uint32_t access, + ArrayRef shape, Type elementType) { if (!isValidElementType(elementType)) { return emitError() << "dispatch tensor elements must be int or float type"; } @@ -93,17 +112,38 @@ LogicalResult DispatchTensorType::verify( template static T parseShapedType(DialectAsmParser &parser) { + StringRef accessStr; SmallVector shape; Type elementType; - if (failed(parser.parseLess()) || + if (failed(parser.parseLess()) || failed(parser.parseKeyword(&accessStr)) || + failed(parser.parseColon()) || failed(parser.parseDimensionList(shape, /*allowDynamic=*/true)) || failed(parser.parseType(elementType)) || failed(parser.parseGreater())) { return {}; } - return T::get(shape, elementType); + auto access = llvm::StringSwitch(accessStr) + .Case("readonly", TensorAccess::ReadOnly) + .Case("readwrite", TensorAccess::ReadWrite) + .Case("writeonly", TensorAccess::WriteOnly) + .Default(TensorAccess::ReadOnly); + return T::get(access, shape, elementType); } static void printShapedType(DispatchTensorType &type, DialectAsmPrinter &p) { + switch (type.getAccess()) { + case TensorAccess::ReadOnly: + p << "readonly"; + break; + case TensorAccess::ReadWrite: + p << "readwrite"; + break; + case TensorAccess::WriteOnly: + p << "writeonly"; + break; + default: + llvm_unreachable("unhandled access"); + } + p << ":"; for (int64_t dim : type.getShape()) { if (ShapedType::isDynamic(dim)) { p << '?'; @@ -116,61 +156,12 @@ static void printShapedType(DispatchTensorType &type, DialectAsmPrinter &p) { } // static -DispatchInputType DispatchInputType::get(ArrayRef shape, - Type elementType) { - return Base::get(elementType.getContext(), shape, elementType); -} - -// static -DispatchInputType DispatchInputType::getChecked(ArrayRef shape, - Type elementType, - Location location) { - return Base::getChecked(location, shape, elementType); -} - -// static -DispatchInputType DispatchInputType::get(TensorType tensorType) { - return DispatchInputType::get(tensorType.getShape(), - tensorType.getElementType()); -} - -// static -DispatchInputType DispatchInputType::parse(DialectAsmParser &parser) { - return parseShapedType(parser); -} - -void printType(DispatchInputType &type, DialectAsmPrinter &p) { - p << "dispatch.input<"; - printShapedType(type, p); - p << '>'; -} - -// static -DispatchOutputType DispatchOutputType::get(ArrayRef shape, - Type elementType) { - return Base::get(elementType.getContext(), shape, elementType); -} - -// static -DispatchOutputType DispatchOutputType::getChecked(ArrayRef shape, - Type elementType, - Location location) { - return Base::getChecked(location, shape, elementType); -} - -// static -DispatchOutputType DispatchOutputType::get(TensorType tensorType) { - return DispatchOutputType::get(tensorType.getShape(), - tensorType.getElementType()); -} - -// static -DispatchOutputType DispatchOutputType::parse(DialectAsmParser &parser) { - return parseShapedType(parser); +DispatchTensorType DispatchTensorType::parse(DialectAsmParser &parser) { + return parseShapedType(parser); } -void printType(DispatchOutputType &type, DialectAsmPrinter &p) { - p << "dispatch.output<"; +void printType(DispatchTensorType &type, DialectAsmPrinter &p) { + p << "dispatch.tensor<"; printShapedType(type, p); p << '>'; } diff --git a/iree/compiler/Dialect/Flow/IR/FlowTypes.h b/iree/compiler/Dialect/Flow/IR/FlowTypes.h index 59abc6f231b3..8a2fdf149715 100644 --- a/iree/compiler/Dialect/Flow/IR/FlowTypes.h +++ b/iree/compiler/Dialect/Flow/IR/FlowTypes.h @@ -43,15 +43,36 @@ namespace detail { struct DispatchTensorTypeStorage; } // namespace detail +enum class TensorAccess : uint32_t { + ReadOnly, + ReadWrite, + WriteOnly, +}; + // Blatantly ripped from ShapedType, because the closed type system means that // we can't extend it and reuse all of this. -class DispatchTensorType : public Type { +class DispatchTensorType + : public Type::TypeBase { public: using ImplType = detail::DispatchTensorTypeStorage; static constexpr int64_t kDynamicSize = -1; - using Type::Type; + using Base::Base; + + /// Get or create a new DispatchTensorType of the provided shape and + /// element type. Assumes the arguments define a well-formed + /// DispatchTensorType. + static DispatchTensorType get(TensorAccess access, ArrayRef shape, + Type elementType); + + static DispatchTensorType get(TensorAccess access, TensorType tensorType); + + static DispatchTensorType parse(DialectAsmParser &parser); + + /// Returns the allowed operations the tensor. + TensorAccess getAccess() const; /// Return the element type. Type getElementType() const; @@ -97,17 +118,15 @@ class DispatchTensorType : public Type { /// dimensions, given its `index` within the shape. unsigned getDynamicDimIndex(unsigned index) const; - /// Methods for support type inquiry through isa, cast, and dyn_cast. - static bool classof(Type type); - /// Whether the given dimension size indicates a dynamic dimension. static constexpr bool isDynamic(int64_t dSize) { return dSize == kDynamicSize; } - /// Verify the construction of a vector type. + /// Verify the construction of a tensor type. static LogicalResult verify(function_ref emitError, - ArrayRef shape, Type elementType); + uint32_t access, ArrayRef shape, + Type elementType); /// Returns true of the given type can be used as an element of a vector type. /// In particular, vectors can consist of integer or float primitives. @@ -124,100 +143,41 @@ class DispatchTensorType : public Type { } }; -class DispatchInputType - : public Type::TypeBase { - public: - using Base::Base; - - /// Get or create a new DispatchInputType of the provided shape and element - /// type. Assumes the arguments define a well-formed DispatchInputType. - static DispatchInputType get(ArrayRef shape, Type elementType); - - /// Get or create a new DispatchInputType of the provided shape and element - /// type declared at the given, potentially unknown, location. If the - /// DispatchInputType defined by the arguments would be ill-formed, emit - /// errors and return nullptr-wrapping type. - static DispatchInputType getChecked(ArrayRef shape, Type elementType, - Location location); - static DispatchInputType getChecked( - function_ref emitError, ArrayRef shape, - Type elementType) { - return Base::getChecked(emitError, elementType.getContext(), shape, - elementType); - } - - static DispatchInputType get(TensorType tensorType); - - static DispatchInputType parse(DialectAsmParser &parser); -}; - -void printType(DispatchInputType &type, DialectAsmPrinter &p); - -class DispatchOutputType - : public Type::TypeBase { - public: - using Base::Base; - - /// Get or create a new DispatchOutputType of the provided shape and element - /// type. Assumes the arguments define a well-formed DispatchOutputType. - static DispatchOutputType get(ArrayRef shape, Type elementType); - - /// Get or create a new DispatchOutputType of the provided shape and element - /// type declared at the given, potentially unknown, location. If the - /// DispatchOutputType defined by the arguments would be ill-formed, emit - /// errors and return nullptr-wrapping type. - static DispatchOutputType getChecked(ArrayRef shape, - Type elementType, Location location); - static DispatchOutputType getChecked( - function_ref emitError, ArrayRef shape, - Type elementType) { - return Base::getChecked(emitError, elementType.getContext(), shape, - elementType); - } - - static DispatchOutputType get(TensorType tensorType); - - static DispatchOutputType parse(DialectAsmParser &parser); -}; - -void printType(DispatchOutputType &type, DialectAsmPrinter &p); - -inline bool DispatchTensorType::classof(Type type) { - return type.isa(); -} +void printType(DispatchTensorType &type, DialectAsmPrinter &p); namespace detail { struct DispatchTensorTypeStorage : public TypeStorage { - DispatchTensorTypeStorage(unsigned shapeSize, Type elementTy, + DispatchTensorTypeStorage(uint32_t access, unsigned shapeSize, Type elementTy, const int64_t *shapeElements) - : shapeElements(shapeElements), + : access(access), + shapeElements(shapeElements), shapeSize(shapeSize), elementType(elementTy) {} /// The hash key used for uniquing. - using KeyTy = std::pair, Type>; + using KeyTy = std::tuple, Type>; bool operator==(const KeyTy &key) const { - return key == KeyTy(getShape(), elementType); + return key == KeyTy(access, getShape(), elementType); } /// Construction. static DispatchTensorTypeStorage *construct(TypeStorageAllocator &allocator, const KeyTy &key) { // Copy the shape into the bump pointer. - ArrayRef shape = allocator.copyInto(key.first); + ArrayRef shape = allocator.copyInto(std::get<1>(key)); // Initialize the memory using placement new. return new (allocator.allocate()) - DispatchTensorTypeStorage(shape.size(), key.second, shape.data()); + DispatchTensorTypeStorage(std::get<0>(key), shape.size(), + std::get<2>(key), shape.data()); } ArrayRef getShape() const { return ArrayRef(shapeElements, shapeSize); } + uint32_t access; const int64_t *shapeElements; unsigned shapeSize; Type elementType; diff --git a/iree/compiler/Dialect/Flow/IR/test/dispatch_ops.mlir b/iree/compiler/Dialect/Flow/IR/test/dispatch_ops.mlir index 8861704fca00..7cb22d0673c3 100644 --- a/iree/compiler/Dialect/Flow/IR/test/dispatch_ops.mlir +++ b/iree/compiler/Dialect/Flow/IR/test/dispatch_ops.mlir @@ -1,5 +1,3 @@ -// Tests printing and parsing of dispatch ops. - // RUN: iree-opt -split-input-file %s | iree-opt -split-input-file | IreeFileCheck %s flow.executable @ex0 { @@ -15,7 +13,33 @@ flow.executable @ex0 { func @dispatch(%arg0 : tensor<4xf32>) -> tensor<4xf32> { // CHECK: %[[CST:.+]] = constant %cst = constant 4 : index - // CHECK: %0 = flow.dispatch @ex0::@dispatch_fn[%[[CST]]](%[[CST]], %arg0) : (index, tensor<4xf32>) -> (tensor<4xf32>) - %0 = flow.dispatch @ex0::@dispatch_fn[%cst](%cst, %arg0) : (index, tensor<4xf32>) -> (tensor<4xf32>) + // CHECK: %0 = flow.dispatch @ex0::@dispatch_fn[%[[CST]]](%[[CST]], %arg0) : (index, tensor<4xf32>) -> tensor<4xf32> + %0 = flow.dispatch @ex0::@dispatch_fn[%cst](%cst, %arg0) : (index, tensor<4xf32>) -> tensor<4xf32> return %0 : tensor<4xf32> } + +// ----- + +// CHECK-LABEL: @inplaceDispatch +func @inplaceDispatch(%arg0 : tensor<4xf32>, %arg1 : tensor<8xf32>) -> (tensor<4xf32>, tensor<8xf32>) { + // CHECK: %[[CST:.+]] = constant + %cst = constant 4 : index + // CHECK: %0:2 = flow.dispatch @ex0::@dispatch_fn[%[[CST]]](%[[CST]], %arg0, %arg1) : (index, tensor<4xf32>, tensor<8xf32>) -> (%arg0, %arg1) + %0, %1 = flow.dispatch @ex0::@dispatch_fn[%cst](%cst, %arg0, %arg1) : (index, tensor<4xf32>, tensor<8xf32>) -> (%arg0, %arg1) + return %0, %1 : tensor<4xf32>, tensor<8xf32> +} + +// ----- + +// CHECK-LABEL: @inplaceDynamicDispatch +func @inplaceDynamicDispatch(%arg0 : tensor<4x?xf32>, %arg1 : tensor<8x?xf32>) -> (tensor<4x?xf32>, tensor<8x?xf32>) { + // CHECK-DAG: %[[CST:.+]] = constant 4 + %cst = constant 4 : index + // CHECK-DAG: %[[DIM0:.+]] = constant 100 + %dim0 = constant 100 : index + // CHECK-DAG: %[[DIM1:.+]] = constant 200 + %dim1 = constant 200 : index + // CHECK: %0:2 = flow.dispatch @ex0::@dispatch_fn[%[[CST]]](%[[CST]], %arg0, %arg1) : (index, tensor<4x?xf32>{%[[DIM0]]}, tensor<8x?xf32>{%[[DIM1]]}) -> (%arg0, %arg1) + %0, %1 = flow.dispatch @ex0::@dispatch_fn[%cst](%cst, %arg0, %arg1) : (index, tensor<4x?xf32>{%dim0}, tensor<8x?xf32>{%dim1}) -> (%arg0, %arg1) + return %0, %1 : tensor<4x?xf32>, tensor<8x?xf32> +} diff --git a/iree/compiler/Dialect/Flow/IR/test/dispatch_workgroups.mlir b/iree/compiler/Dialect/Flow/IR/test/dispatch_workgroups.mlir index 3f42360a34e3..054292c0c503 100644 --- a/iree/compiler/Dialect/Flow/IR/test/dispatch_workgroups.mlir +++ b/iree/compiler/Dialect/Flow/IR/test/dispatch_workgroups.mlir @@ -14,12 +14,12 @@ func @complexWorkgroupsUsage( // CHECK: %[[OUTER_RET0:.+]] = flow.dispatch.workgroups[ // CHECK-SAME: %[[WORKGROUP_COUNT_X]], %[[WORKGROUP_COUNT_Y]] // CHECK-SAME: ](%[[ARG0]], %[[ARG1]]) - // CHECK-SAME: : (tensor{%c128}, index) -> (tensor<4x?xf32>{%c128}) = + // CHECK-SAME: : (tensor{%c128}, index) -> tensor<4x?xf32>{%c128} = %0 = flow.dispatch.workgroups[%x, %y](%arg0, %arg1) : (tensor{%c128}, index) -> tensor<4x?xf32>{%c128} = - // CHECK-NEXT: (%[[INNER_ARG0:.+]]: !flow.dispatch.input + // CHECK-NEXT: (%[[INNER_ARG0:.+]]: !flow.dispatch.tensor // CHECK-SAME: %[[INNER_ARG1:.+]]: index - // CHECK-SAME: %[[INNER_RET0:.+]]: !flow.dispatch.output<4x?xf32>) { - (%arg0_capture: !flow.dispatch.input, %arg1_capture: index, %ret0: !flow.dispatch.output<4x?xf32>) { + // CHECK-SAME: %[[INNER_RET0:.+]]: !flow.dispatch.tensor) { + (%arg0_capture: !flow.dispatch.tensor, %arg1_capture: index, %ret0: !flow.dispatch.tensor) { // Query symbolic workgroup info: @@ -41,8 +41,8 @@ func @complexWorkgroupsUsage( // Query shapes directly from IO (static dims will fold): - // CHECK: %[[ARG0_SHAPE:.+]] = flow.dispatch.shape %[[INNER_ARG0]] : !flow.dispatch.input -> !shapex.ranked_shape<[?,4]> - %arg0_shape = flow.dispatch.shape %arg0_capture : !flow.dispatch.input -> !shapex.ranked_shape<[?,4]> + // CHECK: %[[ARG0_SHAPE:.+]] = flow.dispatch.shape %[[INNER_ARG0]] : !flow.dispatch.tensor -> !shapex.ranked_shape<[?,4]> + %arg0_shape = flow.dispatch.shape %arg0_capture : !flow.dispatch.tensor -> !shapex.ranked_shape<[?,4]> // CHECK-DAG: %[[ARG0_DIM0:.+]] = shapex.ranked_dim %[[ARG0_SHAPE]][0] : !shapex.ranked_shape<[?,4]> -> index %arg0_dim0 = shapex.ranked_dim %arg0_shape[0] : !shapex.ranked_shape<[?,4]> -> index // CHECK-DAG: %[[ARG0_DIM1:.+]] = shapex.ranked_dim %[[ARG0_SHAPE]][1] : !shapex.ranked_shape<[?,4]> -> index @@ -50,8 +50,8 @@ func @complexWorkgroupsUsage( // CHECK-NEXT: "test.sink"(%[[ARG0_DIM0]], %[[ARG0_DIM1]]) "test.sink"(%arg0_dim0, %arg0_dim1) : (index, index) -> () - // CHECK: %[[RET0_SHAPE:.+]] = flow.dispatch.shape %[[INNER_RET0]] : !flow.dispatch.output<4x?xf32> -> !shapex.ranked_shape<[4,?]> - %ret0_shape = flow.dispatch.shape %ret0 : !flow.dispatch.output<4x?xf32> -> !shapex.ranked_shape<[4,?]> + // CHECK: %[[RET0_SHAPE:.+]] = flow.dispatch.shape %[[INNER_RET0]] : !flow.dispatch.tensor -> !shapex.ranked_shape<[4,?]> + %ret0_shape = flow.dispatch.shape %ret0 : !flow.dispatch.tensor -> !shapex.ranked_shape<[4,?]> // CHECK-DAG: %[[RET0_DIM0:.+]] = shapex.ranked_dim %[[RET0_SHAPE]][0] : !shapex.ranked_shape<[4,?]> -> index %ret0_dim0 = shapex.ranked_dim %ret0_shape[0] : !shapex.ranked_shape<[4,?]> -> index // CHECK-DAG: %[[RET0_DIM1:.+]] = shapex.ranked_dim %[[RET0_SHAPE]][1] : !shapex.ranked_shape<[4,?]> -> index @@ -61,8 +61,8 @@ func @complexWorkgroupsUsage( // Load tensors (optional offsets/sizes/strides): - // CHECK: %[[ARG0_VALUE:.+]] = flow.dispatch.input.load %[[INNER_ARG0]] : !flow.dispatch.input -> tensor - %arg0_value = flow.dispatch.input.load %arg0_capture : !flow.dispatch.input -> tensor + // CHECK: %[[ARG0_VALUE:.+]] = flow.dispatch.tensor.load %[[INNER_ARG0]] : !flow.dispatch.tensor -> tensor + %arg0_value = flow.dispatch.tensor.load %arg0_capture : !flow.dispatch.tensor -> tensor // CHECK-NEXT: %[[ARG0_SHAPE_INDIRECT:.+]] = shapex.get_ranked_shape %[[ARG0_VALUE]] : tensor -> !shapex.ranked_shape<[?,4]> %arg0_shape_indirect = shapex.get_ranked_shape %arg0_value : tensor -> !shapex.ranked_shape<[?,4]> @@ -73,8 +73,8 @@ func @complexWorkgroupsUsage( // Store tensors (optional offsets/sizes/strides): - // CHECK: flow.dispatch.output.store %[[RET0_VALUE]], %[[INNER_RET0]] : tensor<4x?xf32> -> !flow.dispatch.output<4x?xf32> - flow.dispatch.output.store %ret0_value, %ret0 : tensor<4x?xf32> -> !flow.dispatch.output<4x?xf32> + // CHECK: flow.dispatch.tensor.store %[[RET0_VALUE]], %[[INNER_RET0]] : tensor<4x?xf32> -> !flow.dispatch.tensor + flow.dispatch.tensor.store %ret0_value, %ret0 : tensor<4x?xf32> -> !flow.dispatch.tensor // CHECK-NEXT: flow.return flow.return @@ -82,3 +82,35 @@ func @complexWorkgroupsUsage( // CHECK: return %[[OUTER_RET0]] : tensor<4x?xf32> return %0 : tensor<4x?xf32> } + +// ----- + +// CHECK-LABEL: @inplaceDispatch +func @inplaceDispatch( + // CHECK-SAME: %[[ARG0:.+]]: tensor + %arg0: tensor, + // CHECK-SAME: %[[ARG1:.+]]: index + %arg1: index) -> tensor { + %c128 = constant 128 : index + // CHECK-DAG: %[[WORKGROUP_COUNT_X:.+]] = constant 100 + %x = constant 100 : index + // CHECK-DAG: %[[WORKGROUP_COUNT_Y:.+]] = constant 50 + %y = constant 50 : index + // CHECK: %[[OUTER_RET0:.+]] = flow.dispatch.workgroups[ + // CHECK-SAME: %[[WORKGROUP_COUNT_X]], %[[WORKGROUP_COUNT_Y]] + // CHECK-SAME: ](%[[ARG0]], %[[ARG1]]) + // CHECK-SAME: : (tensor{%c128}, index) -> %arg0 = + %0 = flow.dispatch.workgroups[%x, %y](%arg0, %arg1) : (tensor{%c128}, index) -> %arg0 = + // CHECK-NEXT: (%[[INNER_ARG0:.+]]: !flow.dispatch.tensor + // CHECK-SAME: %[[INNER_ARG1:.+]]: index) { + (%arg0_capture: !flow.dispatch.tensor, %arg1_capture: index) { + // CHECK: %[[VALUE:.+]] = flow.dispatch.tensor.load %[[INNER_ARG0]] : !flow.dispatch.tensor -> tensor + %t = flow.dispatch.tensor.load %arg0_capture : !flow.dispatch.tensor -> tensor + // CHECK: flow.dispatch.tensor.store %[[VALUE]], %[[INNER_ARG0]] : tensor -> !flow.dispatch.tensor + flow.dispatch.tensor.store %t, %arg0_capture : tensor -> !flow.dispatch.tensor + // CHECK-NEXT: flow.return + flow.return + } + // CHECK: return %[[OUTER_RET0]] : tensor + return %0 : tensor +} diff --git a/iree/compiler/Dialect/Flow/IR/test/dispatch_workgroups_folding.mlir b/iree/compiler/Dialect/Flow/IR/test/dispatch_workgroups_folding.mlir index 906c5f672dda..7cedb3cee5de 100644 --- a/iree/compiler/Dialect/Flow/IR/test/dispatch_workgroups_folding.mlir +++ b/iree/compiler/Dialect/Flow/IR/test/dispatch_workgroups_folding.mlir @@ -6,16 +6,16 @@ func @workgroupStaticShapeDims(%arg0 : tensor) -> tensor<4x?xf32> { %x = constant 100 : index %y = constant 50 : index // CHECK: flow.dispatch.workgroups - %0 = flow.dispatch.workgroups[%x, %y](%arg0) : (tensor{%c128}) -> tensor<4x?xf32>{%c128} = ( - // CHECK-NEXT: (%[[ARG0:.+]]: !flow.dispatch.input, - %arg0_capture: !flow.dispatch.input, - // CHECK-SAME: %[[RET0:.+]]: !flow.dispatch.output<4x?xf32>) - %ret0: !flow.dispatch.output<4x?xf32> + %0 = flow.dispatch.workgroups[%x, %y](%arg0) : (tensor{%c128}) -> (tensor<4x?xf32>{%c128}) = ( + // CHECK-NEXT: (%[[ARG0:.+]]: !flow.dispatch.tensor, + %arg0_capture: !flow.dispatch.tensor, + // CHECK-SAME: %[[RET0:.+]]: !flow.dispatch.tensor) + %ret0: !flow.dispatch.tensor ) { // CHECK: %[[DIM_4:.+]] = constant 4 : index // CHECK: %[[ARG0_SHAPE:.+]] = flow.dispatch.shape %[[ARG0]] - %arg0_shape = flow.dispatch.shape %arg0_capture : !flow.dispatch.input -> !shapex.ranked_shape<[?,4]> + %arg0_shape = flow.dispatch.shape %arg0_capture : !flow.dispatch.tensor -> !shapex.ranked_shape<[?,4]> // CHECK: %[[ARG0_DIM0:.+]] = shapex.ranked_dim %[[ARG0_SHAPE]][0] %arg0_dim0 = shapex.ranked_dim %arg0_shape[0] : !shapex.ranked_shape<[?,4]> -> index %arg0_dim1 = shapex.ranked_dim %arg0_shape[1] : !shapex.ranked_shape<[?,4]> -> index @@ -23,7 +23,7 @@ func @workgroupStaticShapeDims(%arg0 : tensor) -> tensor<4x?xf32> { "test.sink"(%arg0_dim0, %arg0_dim1) : (index, index) -> () // CHECK: %[[RET0_SHAPE:.+]] = flow.dispatch.shape %[[RET0]] - %ret0_shape = flow.dispatch.shape %ret0 : !flow.dispatch.output<4x?xf32> -> !shapex.ranked_shape<[4,?]> + %ret0_shape = flow.dispatch.shape %ret0 : !flow.dispatch.tensor -> !shapex.ranked_shape<[4,?]> %ret0_dim0 = shapex.ranked_dim %ret0_shape[0] : !shapex.ranked_shape<[4,?]> -> index // CHECK: %[[RET0_DIM1:.+]] = shapex.ranked_dim %[[RET0_SHAPE]][1] %ret0_dim1 = shapex.ranked_dim %ret0_shape[1] : !shapex.ranked_shape<[4,?]> -> index @@ -43,9 +43,9 @@ func @workgroupRankFolding(%arg0 : tensor) -> tensor<4x?xf32> { %x = constant 100 : index %y = constant 50 : index // CHECK: flow.dispatch.workgroups - %0 = flow.dispatch.workgroups[%x, %y](%arg0) : (tensor{%c128}) -> tensor<4x?xf32>{%c128} = ( - %arg0_capture: !flow.dispatch.input, - %ret0: !flow.dispatch.output<4x?xf32> + %0 = flow.dispatch.workgroups[%x, %y](%arg0) : (tensor{%c128}) -> (tensor<4x?xf32>{%c128}) = ( + %arg0_capture: !flow.dispatch.tensor, + %ret0: !flow.dispatch.tensor ) { // CHECK: %[[RANK:.+]] = constant 2 : index %workgroup_rank = flow.dispatch.workgroup.rank : index @@ -59,12 +59,12 @@ func @workgroupRankFolding(%arg0 : tensor) -> tensor<4x?xf32> { // ----- // CHECK-LABEL: @convertDimOfDispatchInputLoadToDispatchShape -// CHECK-SAME: %[[ARG:.*]]: !flow.dispatch.input) { -func @convertDimOfDispatchInputLoadToDispatchShape(%arg0: !flow.dispatch.input) { +// CHECK-SAME: %[[ARG:.*]]: !flow.dispatch.tensor) { +func @convertDimOfDispatchInputLoadToDispatchShape(%arg0: !flow.dispatch.tensor) { // CHECK-NEXT: %[[RANKED_SHAPE:.*]] = flow.dispatch.shape %[[ARG]] // CHECK-NEXT: %[[DIM:.*]] = shapex.ranked_dim %[[RANKED_SHAPE]][0] // CHECK-NEXT: "test.sink"(%[[DIM]]) : (index) -> () - %tensor = flow.dispatch.input.load %arg0 : !flow.dispatch.input -> tensor + %tensor = flow.dispatch.tensor.load %arg0 : !flow.dispatch.tensor -> tensor %c0 = constant 0 : index %dim = dim %tensor, %c0 : tensor "test.sink"(%dim) : (index) -> () diff --git a/iree/compiler/Dialect/Flow/IR/test/stream_folding.mlir b/iree/compiler/Dialect/Flow/IR/test/stream_folding.mlir index e68847663585..81ca6be3ebca 100644 --- a/iree/compiler/Dialect/Flow/IR/test/stream_folding.mlir +++ b/iree/compiler/Dialect/Flow/IR/test/stream_folding.mlir @@ -1,5 +1,3 @@ -// Tests folding and canonicalization of stream ops. - // RUN: iree-opt -split-input-file -canonicalize %s | iree-opt -split-input-file | IreeFileCheck %s // CHECK-LABEL: func @inlineConstant @@ -23,8 +21,12 @@ func @removeUnusedCapture(%arg: index) -> index { %unused = constant 5 : index // CHECK: flow.ex.stream.fragment(%[[ARG]]) %0 = flow.ex.stream.fragment(%arg, %unused) : (index, index) -> index = + // CHECK-NEXT: (%[[INNER_ARG:.+]]: index) -> index { (%arg0: index, %arg1: index) -> index { - flow.return %arg0 : index + // CHECK-NEXT: %[[T:.+]] = addi %[[INNER_ARG]], %[[INNER_ARG]] + %t = addi %arg0, %arg0 : index + // CHECK-NEXT: flow.return %[[T]] + flow.return %t : index } return %0 : index } @@ -37,7 +39,8 @@ func @removeUnusedDupCapture(%arg: index) -> index { // CHECK: flow.ex.stream.fragment(%[[ARG]]) %0 = flow.ex.stream.fragment(%arg, %arg) : (index, index) -> index = (%arg0: index, %arg1: index) -> index { - flow.return %arg1 : index + %t = addi %arg0, %arg0 : index + flow.return %t : index } return %0 : index } @@ -49,8 +52,56 @@ func @removeUnusedDupCapture(%arg: index) -> index { func @removeUnusedResult(%arg0: index, %arg1: index) -> index { // CHECK: flow.ex.stream.fragment(%[[ARG1]]) %0:2 = flow.ex.stream.fragment(%arg0, %arg1) : (index, index) -> (index, index) = - (%arg0: index, %arg1: index) -> (index, index) { - flow.return %arg1, %arg0 : index, index + (%unused: index, %arg1: index) -> (index, index) { + %t = addi %arg1, %arg1 : index + flow.return %t, %unused : index, index } return %0#0 : index } + +// ----- + +// CHECK-LABEL: func @removeUnusedDynamicResult +// CHECK-SAME: (%[[ARG0:.+]]: tensor<4x?xf32>, %[[DIM0:.+]]: index, +// CHECK-SAME: %[[ARG1:.+]]: tensor<8x?xf32>, %[[DIM1:.+]]: index) +func @removeUnusedDynamicResult(%arg0: tensor<4x?xf32>, %dim0: index, + %arg1: tensor<8x?xf32>, %dim1: index) -> tensor<8x?xf32> { + // CHECK: flow.ex.stream.fragment(%[[ARG1]]) : + %0:2 = flow.ex.stream.fragment(%arg0, %arg1) : + // CHECK-SAME: (tensor<8x?xf32>{%[[DIM1]]}) -> %[[ARG1]] = + (tensor<4x?xf32>{%dim0}, tensor<8x?xf32>{%dim1}) -> (%arg0, %arg1) = + // CHECK-NEXT: (%[[INNER_ARG:.+]]: tensor<8x?xf32>) -> tensor<8x?xf32> + (%unused: tensor<4x?xf32>, %arg1: tensor<8x?xf32>) -> (tensor<4x?xf32>, tensor<8x?xf32>) { + // CHECK-NEXT: flow.return %[[INNER_ARG]] : tensor<8x?xf32> + flow.return %unused, %arg1 : tensor<4x?xf32>, tensor<8x?xf32> + } + return %0#1 : tensor<8x?xf32> +} + +// ----- + +// Testing inserted clones: a clone here is required as %stream_target is used +// after it is updated. + +// CHECK-LABEL: @dynamicUpdateSliceImmutability +func @dynamicUpdateSliceImmutability( + %target: tensor<2x4xi32>, %update: tensor<1x1xi32>) -> tensor<2x4xi32> { + // CHECK: %[[RET:.+]] = flow.ex.stream.fragment + %ret = flow.ex.stream.fragment(%target, %update) : + (tensor<2x4xi32>, tensor<1x1xi32>) -> tensor<2x4xi32> = + // CHECK-NEXT: (%[[TARGET:.+]]: tensor<2x4xi32>, %[[UPDATE:.+]]: tensor<1x1xi32>) + (%stream_target: tensor<2x4xi32>, %stream_update: tensor<1x1xi32>) -> tensor<2x4xi32> { + %start0 = constant 0 : index + %start1 = constant 1 : index + %workload = constant 8 : index + // CHECK: %[[TARGET_CLONE:.+]] = flow.tensor.clone %[[TARGET]] : tensor<2x4xi32> + // CHECK-NEXT: %[[UPDATED:.+]] = flow.tensor.update %[[UPDATE]], %[[TARGET]] + %t0 = flow.tensor.update %stream_update, %stream_target[%start0, %start1] : tensor<1x1xi32> -> tensor<2x4xi32> + // CHECK-NEXT: %[[RETURN:.+]] = flow.dispatch @ex::@entry[%c8](%[[TARGET_CLONE]], %[[UPDATED]]) + %t1 = flow.dispatch @ex::@entry[%workload](%stream_target, %t0) : (tensor<2x4xi32>, tensor<2x4xi32>) -> tensor<2x4xi32> + // CHECK-NEXT: flow.return %[[RETURN]] + flow.return %t1 : tensor<2x4xi32> + } + // CHECK: return %[[RET]] + return %ret : tensor<2x4xi32> +} diff --git a/iree/compiler/Dialect/Flow/IR/test/tensor_folding.mlir b/iree/compiler/Dialect/Flow/IR/test/tensor_folding.mlir index 570b255855d7..cd9cbe9518dd 100644 --- a/iree/compiler/Dialect/Flow/IR/test/tensor_folding.mlir +++ b/iree/compiler/Dialect/Flow/IR/test/tensor_folding.mlir @@ -105,15 +105,6 @@ func @cloneConst() -> tensor<4xi32> { // ----- -// CHECK-LABEL: @cloneDynamic -func @cloneDynamic(%arg0 : tensor<4xi32>) -> tensor<4xi32> { - %0 = flow.tensor.clone %arg0 : tensor<4xi32> - // CHECK-NEXT: return %arg0 - return %0 : tensor<4xi32> -} - -// ----- - // CHECK-LABEL: @sliceConst0D func @sliceConst0D() -> tensor { %0 = constant dense<0> : tensor diff --git a/iree/compiler/Dialect/Flow/IR/test/types.mlir b/iree/compiler/Dialect/Flow/IR/test/types.mlir index dcae0a2362e1..f64620b82a6e 100644 --- a/iree/compiler/Dialect/Flow/IR/test/types.mlir +++ b/iree/compiler/Dialect/Flow/IR/test/types.mlir @@ -2,26 +2,26 @@ // CHECK-LABEL: @dispatchTypes func @dispatchTypes( - // CHECK-SAME: %arg0: !flow.dispatch.input - %arg0: !flow.dispatch.input, - // CHECK-SAME: %arg1: !flow.dispatch.input<4x4xf32> - %arg1: !flow.dispatch.input<4x4xf32>, - // CHECK-SAME: %arg2: !flow.dispatch.input<1x2x3x4x5x6xf32> - %arg2: !flow.dispatch.input<1x2x3x4x5x6xf32>, - // CHECK-SAME: %arg3: !flow.dispatch.input - %arg3: !flow.dispatch.input, - // CHECK-SAME: %arg4: !flow.dispatch.input<1x?x3xf32> - %arg4: !flow.dispatch.input<1x?x3xf32>, - // CHECK-SAME: %arg5: !flow.dispatch.output - %arg5: !flow.dispatch.output, - // CHECK-SAME: %arg6: !flow.dispatch.output<4x4xf32> - %arg6: !flow.dispatch.output<4x4xf32>, - // CHECK-SAME: %arg7: !flow.dispatch.output<1x2x3x4x5x6xf32> - %arg7: !flow.dispatch.output<1x2x3x4x5x6xf32>, - // CHECK-SAME: %arg8: !flow.dispatch.output - %arg8: !flow.dispatch.output, - // CHECK-SAME: %arg9: !flow.dispatch.output<1x?x3xf32> - %arg9: !flow.dispatch.output<1x?x3xf32> + // CHECK-SAME: %arg0: !flow.dispatch.tensor + %arg0: !flow.dispatch.tensor, + // CHECK-SAME: %arg1: !flow.dispatch.tensor + %arg1: !flow.dispatch.tensor, + // CHECK-SAME: %arg2: !flow.dispatch.tensor + %arg2: !flow.dispatch.tensor, + // CHECK-SAME: %arg3: !flow.dispatch.tensor + %arg3: !flow.dispatch.tensor, + // CHECK-SAME: %arg4: !flow.dispatch.tensor + %arg4: !flow.dispatch.tensor, + // CHECK-SAME: %arg5: !flow.dispatch.tensor + %arg5: !flow.dispatch.tensor, + // CHECK-SAME: %arg6: !flow.dispatch.tensor + %arg6: !flow.dispatch.tensor, + // CHECK-SAME: %arg7: !flow.dispatch.tensor + %arg7: !flow.dispatch.tensor, + // CHECK-SAME: %arg8: !flow.dispatch.tensor + %arg8: !flow.dispatch.tensor, + // CHECK-SAME: %arg9: !flow.dispatch.tensor + %arg9: !flow.dispatch.tensor ) { return } diff --git a/iree/compiler/Dialect/Flow/Transforms/CreateBenchmarkFuncs.cpp b/iree/compiler/Dialect/Flow/Transforms/CreateBenchmarkFuncs.cpp index 656158c8da20..711a0cd28977 100644 --- a/iree/compiler/Dialect/Flow/Transforms/CreateBenchmarkFuncs.cpp +++ b/iree/compiler/Dialect/Flow/Transforms/CreateBenchmarkFuncs.cpp @@ -69,7 +69,8 @@ class CreateBenchmarkFuncs auto dummyWorkload = blockBuilder.create(loc, 0); auto dispatchOp = blockBuilder.create( loc, dispatchEntryOp, ValueRange{dummyWorkload}, - funcType.getResults(), ValueRange{}, args, ValueRange{}); + funcType.getResults(), ValueRange{}, args, ValueRange{}, + ArrayRef{}); blockBuilder.create(loc, dispatchOp.getResults()); } } diff --git a/iree/compiler/Dialect/Flow/Transforms/DestructiveUpdateUtils.cpp b/iree/compiler/Dialect/Flow/Transforms/DestructiveUpdateUtils.cpp index 0835445df6f5..57f3d4343069 100644 --- a/iree/compiler/Dialect/Flow/Transforms/DestructiveUpdateUtils.cpp +++ b/iree/compiler/Dialect/Flow/Transforms/DestructiveUpdateUtils.cpp @@ -221,7 +221,7 @@ static Value isADestructiveUpdatePattern(Value tensor, static LogicalResult propagateSubTensorOp(OpBuilder &b, SubTensorOp op) { OpBuilder::InsertionGuard g(b); b.setInsertionPoint(op); - auto loadOp = op.source().getDefiningOp(); + auto loadOp = op.source().getDefiningOp(); if (!loadOp) { BlockArgument val = op.source().dyn_cast(); while (val) { @@ -230,7 +230,7 @@ static LogicalResult propagateSubTensorOp(OpBuilder &b, SubTensorOp op) { if (!forOp) return failure(); unsigned idx = val.getArgNumber() - 1; // accounting for IV arg. Value iterOperand = *(forOp.getIterOperands().begin() + idx); - loadOp = iterOperand.getDefiningOp(); + loadOp = iterOperand.getDefiningOp(); val = iterOperand.dyn_cast(); } } @@ -243,7 +243,7 @@ static LogicalResult propagateSubTensorOp(OpBuilder &b, SubTensorOp op) { sizes.push_back(r.size); strides.push_back(r.stride); } - Value loaded = b.create( + Value loaded = b.create( op.getLoc(), op.getResult().getType(), loadOp.source(), offsets, sizes, strides); op.getResult().replaceAllUsesWith(loaded); @@ -277,7 +277,7 @@ static LogicalResult rewriteSubTensorInsertInPlace(OpBuilder &b, sizes.push_back(r.size); strides.push_back(r.stride); } - b.create(op.getLoc(), op.source(), target, + b.create(op.getLoc(), op.source(), target, offsets, sizes, strides); return success(); } @@ -418,7 +418,7 @@ static LogicalResult rewriteDestructiveUpdateInPlace(OpBuilder &b, Value v, // Reload the value produced inplace right after the inplace update. OpBuilder::InsertionGuard g(b); b.setInsertionPointAfter(outermostProducingOp); - Value newLoad = b.create( + Value newLoad = b.create( outermostProducingOp->getLoc(), v.getType(), target); // TODO(nicolasvasilache): this brutally replaces all uses by the result of // this load. In practice we may want more recompute and we may have lost @@ -435,8 +435,8 @@ static LogicalResult rewriteDestructiveUpdateInPlace(OpBuilder &b, Value v, // consecutive ops". Probably better to wait until core alias analysis is // upstreamed. // TODO(nicolasvasilache): interfaces. -static bool hasInterleavedAliases(IREE::Flow::DispatchInputLoadOp loadOp, - IREE::Flow::DispatchOutputStoreOp storeOp) { +static bool hasInterleavedAliases(IREE::Flow::DispatchTensorLoadOp loadOp, + IREE::Flow::DispatchTensorStoreOp storeOp) { Block *bLoad = loadOp.getOperation()->getBlock(); Block *bStore = loadOp.getOperation()->getBlock(); if (!isa(bLoad->getParentOp()) || @@ -461,7 +461,7 @@ LogicalResult rewriteLinalgDestructiveUpdates( // For each tensor store op, look for destructive updates and replace the // destructive pattern by a custom inplace update pattern. bool fail = dispatchOp - .walk([&](IREE::Flow::DispatchOutputStoreOp op) { + .walk([&](IREE::Flow::DispatchTensorStoreOp op) { if (failed(rewriteDestructiveUpdateInPlace(b, op.value(), op.target()))) { return WalkResult::interrupt(); @@ -472,8 +472,8 @@ LogicalResult rewriteLinalgDestructiveUpdates( if (fail) return failure(); // For each tensor store op, redundant load/store optimization. - dispatchOp.walk([&](IREE::Flow::DispatchOutputStoreOp storeOp) { - auto loadOp = dyn_cast_or_null( + dispatchOp.walk([&](IREE::Flow::DispatchTensorStoreOp storeOp) { + auto loadOp = dyn_cast_or_null( storeOp.value().getDefiningOp()); // Bail if there exists an interleaved aliasing. if (!loadOp || hasInterleavedAliases(loadOp, storeOp)) { diff --git a/iree/compiler/Dialect/Flow/Transforms/DestructiveUpdateUtils.h b/iree/compiler/Dialect/Flow/Transforms/DestructiveUpdateUtils.h index 106d3c981df3..7c747a074122 100644 --- a/iree/compiler/Dialect/Flow/Transforms/DestructiveUpdateUtils.h +++ b/iree/compiler/Dialect/Flow/Transforms/DestructiveUpdateUtils.h @@ -52,27 +52,30 @@ namespace Flow { // // The following destructive update patterns are rewritten. // -// Coming from an `Flow::DispatchInputLoadOp` +// Coming from an `Flow::DispatchTensorLoadOp` // ========================================== // ``` -// %0 = flow.dispatch.input.load %a : !flow.dispatch.input<...> -> tensor<...> +// %0 = flow.dispatch.tensor.load %a : !flow.dispatch.tensor -> +// tensor<...> // ... // %1 = destructive_update(%0) // ... -// use_of(%1) // e.g. flow.dispatch.output.store %1, %b : -// // tensor<...> -> !flow.dispatch.output<...> +// use_of(%1) // e.g. flow.dispatch.tensor.store %1, %b : +// // tensor<...> -> !flow.dispatch.tensor // ``` // is rewritten into: // ``` -// %0 = flow.dispatch.input.load %a : !flow.dispatch.input<...> -> tensor<...> +// %0 = flow.dispatch.tensor.load %a : !flow.dispatch.tensor -> +// tensor<...> // ... -// inplace_update(%0, %out) //e.g. flow.dispatch.output.store %subtensor, %b, +// inplace_update(%0, %out) //e.g. flow.dispatch.tensor.store %subtensor, %b, // // offsets = ..., sizes = ..., strides = ... : -// // tensor<...> -> !flow.dispatch.output<...> +// // tensor<...> -> +// !flow.dispatch.tensor // %2 = flow.dispatch.output.load %b // ... -// use_of(%2) // e.g. flow.dispatch.output.store %2, %b : -// // tensor<...> -> !flow.dispatch.output<...> +// use_of(%2) // e.g. flow.dispatch.tensor.store %2, %b : +// // tensor<...> -> !flow.dispatch.tensor // ``` // // This is a typical pattern that appears after tiling Linalg ops on tensors @@ -84,8 +87,8 @@ namespace Flow { // ``` // %2 = flow.dispatch.output.load %b // ... -// flow.dispatch.output.store %2, %b : -// tensor<...> -> !flow.dispatch.output<...> +// flow.dispatch.tensor.store %2, %b : +// tensor<...> -> !flow.dispatch.tensor // ``` // is elided. LogicalResult rewriteLinalgDestructiveUpdates( diff --git a/iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensors.cpp b/iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensors.cpp index f84b74350f17..8f8a428f0742 100644 --- a/iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensors.cpp +++ b/iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensors.cpp @@ -226,7 +226,8 @@ buildOperandLessFlowDispatchWorkgroupOp(PatternRewriter &rewriter, Location loc, auto dispatchOp = rewriter.create( loc, count, op->getResultTypes(), /*result_dims=*/ValueRange{}, /*operands=*/ValueRange{}, - /*operand_dims=*/ValueRange{}); + /*operand_dims=*/ValueRange{}, + /*tied_operands=*/ArrayRef{}); Region ®ion = dispatchOp.body(); Block *block = ®ion.front(); Operation *clonedOp; @@ -237,7 +238,7 @@ buildOperandLessFlowDispatchWorkgroupOp(PatternRewriter &rewriter, Location loc, for (auto it : llvm::zip(clonedOp->getResults(), dispatchOp.body().getArguments().take_back( clonedOp->getNumResults()))) { - rewriter.create( + rewriter.create( loc, std::get<0>(it), std::get<1>(it), llvm::None, llvm::None, llvm::None); } @@ -387,18 +388,18 @@ static LogicalResult legalizeDispatchWorkgroupOperands( // Replace valuesDefinedAbove by new BB args (including the op's operands). for (Value operand : valuesDefinedAbove) { if (auto rt = operand.getType().dyn_cast()) { - block.addArgument(IREE::Flow::DispatchInputType::get( - rt.getShape(), rt.getElementType())); + block.addArgument(IREE::Flow::DispatchTensorType::get( + TensorAccess::ReadOnly, rt.getShape(), rt.getElementType())); } else { block.addArgument(operand.getType()); } Value bbArg = block.getArguments().back(); Value repl = bbArg; - if (bbArg.getType().isa()) { - repl = b.create(loc, operand.getType(), - bbArg); - } else if (bbArg.getType().isa()) { + if (bbArg.getType().isa()) { + repl = b.create(loc, operand.getType(), + bbArg); + } else if (bbArg.getType().isa()) { // TODO(nicolasvasilache): do something useful. continue; } @@ -408,7 +409,7 @@ static LogicalResult legalizeDispatchWorkgroupOperands( // The only existing arguments are for the outputs. Just need to add a new // argument for the outputs and remap the value to use the new argument. for (auto ba : block.getArguments().take_front(numOldBBArgs)) { - assert(ba.getType().isa()); + assert(ba.getType().isa()); map.map(ba, block.addArgument(ba.getType())); } @@ -546,7 +547,7 @@ struct TileAndDistributeOnTensorsPattern linalgOp.getLoc(), result, rewriter)); } - // Note: DispatchOutputStoreOp generated by the + // Note: DispatchTensorStoreOp generated by the // `buildOperandLessFlowDispatchWorkgroupOp` is an abstraction jump that // consumes the SSA value produced by `clonedOp` but it does not comply with // the semantics of DispatchWorkgroupsOp which explicitly states: "behavior diff --git a/iree/compiler/Dialect/Flow/Transforms/FormStreams.cpp b/iree/compiler/Dialect/Flow/Transforms/FormStreams.cpp index 21c51f07831b..85fbbdc06e0a 100644 --- a/iree/compiler/Dialect/Flow/Transforms/FormStreams.cpp +++ b/iree/compiler/Dialect/Flow/Transforms/FormStreams.cpp @@ -120,6 +120,7 @@ class FormStreamsPass : public PassWrapper { SmallVector fragmentResults; SmallVector fragmentResultDims; SmallVector fragmentResultTypes; + SmallVector fragmentTiedOperands; for (auto *op : streamOps) { for (auto operand : op->getOperands()) { if (std::find(fragmentOperands.begin(), fragmentOperands.end(), @@ -159,7 +160,7 @@ class FormStreamsPass : public PassWrapper { // Create the fragment and clone in all of the ops. auto fragmentOp = blockBuilder.create( fragmentLoc, fragmentResultTypes, fragmentResultDims, fragmentOperands, - fragmentOperandDims); + fragmentOperandDims, fragmentTiedOperands); auto *entryBlock = new Block(); fragmentOp.body().getBlocks().push_back(entryBlock); entryBlock->addArguments(TypeRange(fragmentOp.operands())); diff --git a/iree/compiler/Dialect/Flow/Transforms/OutlineDispatchRegions.cpp b/iree/compiler/Dialect/Flow/Transforms/OutlineDispatchRegions.cpp index 7816d6700178..81ec219d4073 100644 --- a/iree/compiler/Dialect/Flow/Transforms/OutlineDispatchRegions.cpp +++ b/iree/compiler/Dialect/Flow/Transforms/OutlineDispatchRegions.cpp @@ -96,7 +96,7 @@ LogicalResult convertToDispatchOp(DispatchRegionOp regionOp, auto dispatchOp = builder.create( regionOp.getLoc(), entryPointOp, ValueRange{regionOp.workload()}, outlinedFuncOp.getType().getResults(), resultDynamicDims, newArgs, - operandDynamicDims); + operandDynamicDims, ArrayRef{}); if (traceDispatchTensors) { std::string str = "Output for " + std::string(outlinedFuncOp.getName()); diff --git a/iree/compiler/Dialect/Flow/Transforms/OutlineDispatchRegions2.cpp b/iree/compiler/Dialect/Flow/Transforms/OutlineDispatchRegions2.cpp index 1ffbd9c86aca..da32fcbcea90 100644 --- a/iree/compiler/Dialect/Flow/Transforms/OutlineDispatchRegions2.cpp +++ b/iree/compiler/Dialect/Flow/Transforms/OutlineDispatchRegions2.cpp @@ -72,10 +72,12 @@ static LogicalResult convertToDispatchOp(DispatchWorkgroupsOp regionOp, } // Create the dispatch op to the executable function. + // Note that we copy the tied operand indices from the workgroups op - it + // lines up 1:1 with the dispatch once we've outlined things. auto dispatchOp = builder.create( regionOp.getLoc(), entryPointOp, regionOp.workgroup_count(), regionOp.getResultTypes(), resultDynamicDims, newOperands, - operandDynamicDims); + operandDynamicDims, regionOp.getTiedResultOperandIndices()); // Replace uses of the existing results with the new results. for (int i = 0; i < regionOp.getNumResults(); ++i) { @@ -102,13 +104,9 @@ static FuncOp createWorkgroupFunc(Location loc, StringRef functionName, SmallVector operandTypes; int64_t totalDynamicDims = 0; for (auto &operand : region.getArguments()) { - if (auto inputType = operand.getType().dyn_cast()) { - operandTypes.push_back(inputType); - totalDynamicDims += inputType.getNumDynamicDims(); - } else if (auto outputType = - operand.getType().dyn_cast()) { - operandTypes.push_back(outputType); - totalDynamicDims += outputType.getNumDynamicDims(); + if (auto tensorType = operand.getType().dyn_cast()) { + operandTypes.push_back(tensorType); + totalDynamicDims += tensorType.getNumDynamicDims(); } else { // Pass-through. operandTypes.push_back(operand.getType()); diff --git a/iree/compiler/Dialect/Flow/Transforms/test/create_benchmark_funcs.mlir b/iree/compiler/Dialect/Flow/Transforms/test/create_benchmark_funcs.mlir index 6f03139ee293..804ea79a74f0 100644 --- a/iree/compiler/Dialect/Flow/Transforms/test/create_benchmark_funcs.mlir +++ b/iree/compiler/Dialect/Flow/Transforms/test/create_benchmark_funcs.mlir @@ -12,8 +12,8 @@ module { // CHECK: func @two_dispatch_ex_dispatch_0_entry // CHECK: %{{.+}} = flow.variable.load @[[IN0_0]] : tensor<5x3xf32> // CHECK: %{{.+}} = flow.variable.load @[[IN0_1]] : tensor<3x5xf32> -// CHECK: %[[RES:.+]] = flow.ex.stream.fragment({{.+}}) -> (tensor<5x5xf32>) = -// CHECK: %[[DISPATCH_RES:.+]] = flow.dispatch @two_dispatch_ex_dispatch_0::@two_dispatch_ex_dispatch_0[%{{.+}}](%{{.+}}, %{{.+}}) : (tensor<5x3xf32>, tensor<3x5xf32>) -> (tensor<5x5xf32>) +// CHECK: %[[RES:.+]] = flow.ex.stream.fragment({{.+}}) -> tensor<5x5xf32> = +// CHECK: %[[DISPATCH_RES:.+]] = flow.dispatch @two_dispatch_ex_dispatch_0::@two_dispatch_ex_dispatch_0[%{{.+}}](%{{.+}}, %{{.+}}) : (tensor<5x3xf32>, tensor<3x5xf32>) -> tensor<5x5xf32> // CHECK: flow.return %[[DISPATCH_RES]] : tensor<5x5xf32> // CHECK: return %[[RES]] : tensor<5x5xf32> // @@ -22,8 +22,8 @@ module { // CHECK: func @two_dispatch_ex_dispatch_1_entry // CHECK: %{{.+}} = flow.variable.load @[[IN1_0]] : tensor<3x5xf32> // CHECK: %{{.+}} = flow.variable.load @[[IN1_1]] : tensor<5x5xf32> -// CHECK: %[[RES:.+]] = flow.ex.stream.fragment({{.+}}) -> (tensor<3x5xf32>) -// CHECK: %[[DISPATCH_RES:.+]] = flow.dispatch @two_dispatch_ex_dispatch_1::@two_dispatch_ex_dispatch_1[%{{.+}}](%{{.+}}, %{{.+}}) : (tensor<3x5xf32>, tensor<5x5xf32>) -> (tensor<3x5xf32>) +// CHECK: %[[RES:.+]] = flow.ex.stream.fragment({{.+}}) -> tensor<3x5xf32> +// CHECK: %[[DISPATCH_RES:.+]] = flow.dispatch @two_dispatch_ex_dispatch_1::@two_dispatch_ex_dispatch_1[%{{.+}}](%{{.+}}, %{{.+}}) : (tensor<3x5xf32>, tensor<5x5xf32>) -> tensor<3x5xf32> // CHECK: flow.return %[[DISPATCH_RES]] : tensor<3x5xf32> // CHECK: return %[[RES]] : tensor<3x5xf32> // diff --git a/iree/compiler/Dialect/Flow/Transforms/test/deduplicate_executables.mlir b/iree/compiler/Dialect/Flow/Transforms/test/deduplicate_executables.mlir index 916becdb65bd..d5eaee6cf93d 100644 --- a/iree/compiler/Dialect/Flow/Transforms/test/deduplicate_executables.mlir +++ b/iree/compiler/Dialect/Flow/Transforms/test/deduplicate_executables.mlir @@ -13,7 +13,7 @@ flow.executable @single_executable_ex_0 { // CHECK-LABEL: func @single_executable func @single_executable(%arg0: tensor<4xf32>) -> tensor<4xf32> { %c4 = constant 4 : index - // CHECK: %0 = flow.dispatch @single_executable_ex_0::@single_executable_entry_0[%c4](%arg0) : (tensor<4xf32>) -> (tensor<4xf32>) + // CHECK: %0 = flow.dispatch @single_executable_ex_0::@single_executable_entry_0[%c4](%arg0) : (tensor<4xf32>) -> tensor<4xf32> %0 = flow.dispatch @single_executable_ex_0::@single_executable_entry_0[%c4] (%arg0) : (tensor<4xf32>) -> tensor<4xf32> return %0 : tensor<4xf32> } @@ -53,11 +53,11 @@ flow.executable @duplicate_executables_ex_2 { // CHECK-LABEL: func @duplicate_executables func @duplicate_executables(%arg0: tensor<4xf32>) -> tensor<4xf32> { %c4 = constant 4 : index - // CHECK: %0 = flow.dispatch @duplicate_executables_ex_0::@duplicate_executables_entry_0[%c4](%arg0) : (tensor<4xf32>) -> (tensor<4xf32>) + // CHECK: %0 = flow.dispatch @duplicate_executables_ex_0::@duplicate_executables_entry_0[%c4](%arg0) : (tensor<4xf32>) -> tensor<4xf32> %0 = flow.dispatch @duplicate_executables_ex_0::@duplicate_executables_entry_0[%c4] (%arg0) : (tensor<4xf32>) -> tensor<4xf32> - // CHECK: %1 = flow.dispatch @duplicate_executables_ex_0::@duplicate_executables_entry_0[%c4](%arg0) : (tensor<4xf32>) -> (tensor<4xf32>) + // CHECK: %1 = flow.dispatch @duplicate_executables_ex_0::@duplicate_executables_entry_0[%c4](%arg0) : (tensor<4xf32>) -> tensor<4xf32> %1 = flow.dispatch @duplicate_executables_ex_1::@duplicate_executables_entry_1[%c4] (%arg0) : (tensor<4xf32>) -> tensor<4xf32> - // CHECK: %2 = flow.dispatch @duplicate_executables_ex_2::@duplicate_executables_entry_2[%c4](%arg0) : (tensor<4xf32>) -> (tensor<4xf32>) + // CHECK: %2 = flow.dispatch @duplicate_executables_ex_2::@duplicate_executables_entry_2[%c4](%arg0) : (tensor<4xf32>) -> tensor<4xf32> %2 = flow.dispatch @duplicate_executables_ex_2::@duplicate_executables_entry_2[%c4] (%arg0) : (tensor<4xf32>) -> tensor<4xf32> return %0 : tensor<4xf32> } @@ -87,9 +87,9 @@ flow.executable @same_ops_diff_operands_ex_1 { // CHECK-LABEL: func @same_ops_diff_operands func @same_ops_diff_operands(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2xi32> { %c4 = constant 4 : index - // CHECK: %0 = flow.dispatch @same_ops_diff_operands_ex_0::@entry_0[%c4](%arg0, %arg1) : (tensor<2xi32>, tensor<2xi32>) -> (tensor<2xi32>) + // CHECK: %0 = flow.dispatch @same_ops_diff_operands_ex_0::@entry_0[%c4](%arg0, %arg1) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> %0 = flow.dispatch @same_ops_diff_operands_ex_0::@entry_0[%c4] (%arg0, %arg1) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> - // CHECK: %1 = flow.dispatch @same_ops_diff_operands_ex_1::@entry_1[%c4](%arg0, %arg1) : (tensor<2xi32>, tensor<2xi32>) -> (tensor<2xi32>) + // CHECK: %1 = flow.dispatch @same_ops_diff_operands_ex_1::@entry_1[%c4](%arg0, %arg1) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> %1 = flow.dispatch @same_ops_diff_operands_ex_1::@entry_1[%c4] (%arg0, %arg1) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> return %0 : tensor<2xi32> } @@ -129,13 +129,13 @@ flow.executable @multiple_entry_points_ex_1 { // CHECK-LABEL: func @multiple_entry_points func @multiple_entry_points(%arg0: tensor<4xf32>) -> tensor<4xf32> { %c4 = constant 4 : index - // CHECK: %0 = flow.dispatch @multiple_entry_points_ex_0::@multiple_entry_points_0_entry_0[%c4](%arg0) : (tensor<4xf32>) -> (tensor<4xf32>) + // CHECK: %0 = flow.dispatch @multiple_entry_points_ex_0::@multiple_entry_points_0_entry_0[%c4](%arg0) : (tensor<4xf32>) -> tensor<4xf32> %0 = flow.dispatch @multiple_entry_points_ex_0::@multiple_entry_points_0_entry_0[%c4] (%arg0) : (tensor<4xf32>) -> tensor<4xf32> - // CHECK: %1 = flow.dispatch @multiple_entry_points_ex_0::@multiple_entry_points_0_entry_1[%c4](%arg0) : (tensor<4xf32>) -> (tensor<4xf32>) + // CHECK: %1 = flow.dispatch @multiple_entry_points_ex_0::@multiple_entry_points_0_entry_1[%c4](%arg0) : (tensor<4xf32>) -> tensor<4xf32> %1 = flow.dispatch @multiple_entry_points_ex_0::@multiple_entry_points_0_entry_1[%c4] (%arg0) : (tensor<4xf32>) -> tensor<4xf32> - // CHECK: %2 = flow.dispatch @multiple_entry_points_ex_0::@multiple_entry_points_0_entry_0[%c4](%arg0) : (tensor<4xf32>) -> (tensor<4xf32>) + // CHECK: %2 = flow.dispatch @multiple_entry_points_ex_0::@multiple_entry_points_0_entry_0[%c4](%arg0) : (tensor<4xf32>) -> tensor<4xf32> %2 = flow.dispatch @multiple_entry_points_ex_1::@multiple_entry_points_1_entry_0[%c4] (%arg0) : (tensor<4xf32>) -> tensor<4xf32> - // CHECK: %3 = flow.dispatch @multiple_entry_points_ex_0::@multiple_entry_points_0_entry_1[%c4](%arg0) : (tensor<4xf32>) -> (tensor<4xf32>) + // CHECK: %3 = flow.dispatch @multiple_entry_points_ex_0::@multiple_entry_points_0_entry_1[%c4](%arg0) : (tensor<4xf32>) -> tensor<4xf32> %3 = flow.dispatch @multiple_entry_points_ex_1::@multiple_entry_points_1_entry_1[%c4] (%arg0) : (tensor<4xf32>) -> tensor<4xf32> return %0 : tensor<4xf32> } @@ -165,9 +165,9 @@ flow.executable @different_types_int_ex { // CHECK-LABEL: func @different_types func @different_types(%arg0: tensor<4xf32>) -> tensor<4xi1> { %c4 = constant 4 : index - // CHECK: %0 = flow.dispatch @different_types_float_ex::@different_types_float_entry[%c4](%arg0) : (tensor<4xf32>) -> (tensor<4xi1>) + // CHECK: %0 = flow.dispatch @different_types_float_ex::@different_types_float_entry[%c4](%arg0) : (tensor<4xf32>) -> tensor<4xi1> %0 = flow.dispatch @different_types_float_ex::@different_types_float_entry[%c4] (%arg0) : (tensor<4xf32>) -> tensor<4xi1> - // CHECK: %1 = flow.dispatch @different_types_int_ex::@different_types_int_entry[%c4](%arg0) : (tensor<4xf32>) -> (tensor<4xi1>) + // CHECK: %1 = flow.dispatch @different_types_int_ex::@different_types_int_entry[%c4](%arg0) : (tensor<4xf32>) -> tensor<4xi1> %1 = flow.dispatch @different_types_int_ex::@different_types_int_entry[%c4] (%arg0) : (tensor<4xf32>) -> tensor<4xi1> return %0 : tensor<4xi1> } @@ -222,11 +222,11 @@ flow.executable @nested_ops_ex_2 { // CHECK-LABEL: func @nested_ops func @nested_ops(%arg0: tensor<1x4xi32>) -> tensor<1xi32> { %c4 = constant 4 : index - // CHECK: %0 = flow.dispatch @nested_ops_ex_0::@nested_ops_entry_0[%c4](%arg0) : (tensor<1x4xi32>) -> (tensor<1xi32>) + // CHECK: %0 = flow.dispatch @nested_ops_ex_0::@nested_ops_entry_0[%c4](%arg0) : (tensor<1x4xi32>) -> tensor<1xi32> %0 = flow.dispatch @nested_ops_ex_0::@nested_ops_entry_0[%c4] (%arg0) : (tensor<1x4xi32>) -> tensor<1xi32> - // CHECK: %1 = flow.dispatch @nested_ops_ex_0::@nested_ops_entry_0[%c4](%arg0) : (tensor<1x4xi32>) -> (tensor<1xi32>) + // CHECK: %1 = flow.dispatch @nested_ops_ex_0::@nested_ops_entry_0[%c4](%arg0) : (tensor<1x4xi32>) -> tensor<1xi32> %1 = flow.dispatch @nested_ops_ex_0::@nested_ops_entry_0[%c4] (%arg0) : (tensor<1x4xi32>) -> tensor<1xi32> - // CHECK: %2 = flow.dispatch @nested_ops_ex_2::@nested_ops_entry_2[%c4](%arg0) : (tensor<1x4xi32>) -> (tensor<1xi32>) + // CHECK: %2 = flow.dispatch @nested_ops_ex_2::@nested_ops_entry_2[%c4](%arg0) : (tensor<1x4xi32>) -> tensor<1xi32> %2 = flow.dispatch @nested_ops_ex_2::@nested_ops_entry_2[%c4] (%arg0) : (tensor<1x4xi32>) -> tensor<1xi32> return %0 : tensor<1xi32> } diff --git a/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors.mlir b/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors.mlir index ac5225c94c90..4bbb5536f7b3 100644 --- a/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors.mlir +++ b/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors.mlir @@ -18,11 +18,11 @@ func @tensor() -> tensor<2x4xf32> { // %[[C2]] will be handled by a later RematerializeDispatchConstants // CHECK: flow.dispatch.workgroups[%[[C4wg]], %[[C2wg]], %[[C1wg]]](%[[outerA]], %[[outerB]], %[[outerC]]) : - // CHECK-SAME: (tensor<2x3xf32>, tensor<3x4xf32>, tensor<2x4xf32>) -> (tensor<2x4xf32>) = - // CHECK-NEXT: (%[[A:[0-9a-z]*]]: !flow.dispatch.input<2x3xf32>, - // CHECK-SAME: %[[B:[0-9a-z]*]]: !flow.dispatch.input<3x4xf32>, - // CHECK-SAME: %[[C:[0-9a-z]*]]: !flow.dispatch.input<2x4xf32>, - // CHECK-SAME: %[[OUT:[0-9a-z]*]]: !flow.dispatch.output<2x4xf32>) { + // CHECK-SAME: (tensor<2x3xf32>, tensor<3x4xf32>, tensor<2x4xf32>) -> tensor<2x4xf32> = + // CHECK-NEXT: (%[[A:[0-9a-z]*]]: !flow.dispatch.tensor, + // CHECK-SAME: %[[B:[0-9a-z]*]]: !flow.dispatch.tensor, + // CHECK-SAME: %[[C:[0-9a-z]*]]: !flow.dispatch.tensor, + // CHECK-SAME: %[[OUT:[0-9a-z]*]]: !flow.dispatch.tensor) { // CHECK-DAG: %[[C0:.*]] = constant 0 : index // CHECK-DAG: %[[C1:.*]] = constant 1 : index // CHECK-DAG: %[[C2:.*]] = constant 2 : index @@ -38,23 +38,23 @@ func @tensor() -> tensor<2x4xf32> { // CHECK-NEXT: scf.for %[[J:.*]] = %[[bix_scaled]] to %[[C4]] step %[[bdx_scaled]] { // Canonicalizations not yet powerful enough here. // CHECK-NEXT: %[[MIN_I:.*]] = affine.min{{.*}}(%[[I]]) - // CHECK-NEXT: %[[AA:.*]] = flow.dispatch.input.load %[[A]], + // CHECK-NEXT: %[[AA:.*]] = flow.dispatch.tensor.load %[[A]], // CHECK-SAME: offsets = [%[[I]], %[[C0]]], sizes = [%[[MIN_I]], %[[C3]]], strides = [%[[C1]], %[[C1]]] : - // CHECK-SAME: !flow.dispatch.input<2x3xf32> -> tensor + // CHECK-SAME: !flow.dispatch.tensor -> tensor // // Canonicalizations not yet powerful enough here. // CHECK-NEXT: %[[MIN_J:.*]] = affine.min{{.*}}(%[[J]]) - // CHECK-NEXT: %[[BB:.*]] = flow.dispatch.input.load %[[B]], + // CHECK-NEXT: %[[BB:.*]] = flow.dispatch.tensor.load %[[B]], // CHECK-SAME: offsets = [%[[C0]], %[[J]]], sizes = [%[[C3]], %[[MIN_J]]], strides = [%[[C1]], %[[C1]]] : - // CHECK-SAME: !flow.dispatch.input<3x4xf32> -> tensor<3x?xf32> - // CHECK-NEXT: %[[CC:.*]] = flow.dispatch.input.load %[[C]], + // CHECK-SAME: !flow.dispatch.tensor -> tensor<3x?xf32> + // CHECK-NEXT: %[[CC:.*]] = flow.dispatch.tensor.load %[[C]], // CHECK-SAME: offsets = [%[[I]], %[[J]]], sizes = [%[[MIN_I]], %[[MIN_J]]], strides = [%[[C1]], %[[C1]]] : - // CHECK-SAME: !flow.dispatch.input<2x4xf32> -> tensor + // CHECK-SAME: !flow.dispatch.tensor -> tensor // CHECK-NEXT: %[[RES:.*]] = linalg.matmul {__internal_linalg_transform__ = "workgroup"} ins(%[[AA]], %[[BB]] : // CHECK-SAME: tensor, tensor<3x?xf32>) outs(%[[CC]] : tensor) -> tensor - // CHECK-NEXT: flow.dispatch.output.store %[[RES]], %[[OUT]], + // CHECK-NEXT: flow.dispatch.tensor.store %[[RES]], %[[OUT]], // CHECK-SAME: offsets = [%[[I]], %[[J]]], sizes = [%[[MIN_I]], %[[MIN_J]]], strides = [%[[C1]], %[[C1]]] : - // CHECK-SAME: tensor -> !flow.dispatch.output<2x4xf32> + // CHECK-SAME: tensor -> !flow.dispatch.tensor // CHECK-NEXT: } // CHECK-NEXT: } // CHECK-NEXT: flow.return @@ -163,12 +163,12 @@ func @tensor5(%A: tensor, %B: tensor, %C: tensor) // CHECK-DAG: %[[D0:.+]] = dim %[[ARG2]], %[[C0]] // CHECK-DAG: %[[D1:.+]] = dim %[[ARG2]], %[[C1]] // CHECK: %[[origCC:.+]] = flow.dispatch.workgroups[%[[D1]], %[[D0]], %[[C1]]](%[[ARG2]]) - // CHECK-NEXT: %[[ARG3:.+]]: !flow.dispatch.input - // CHECK-SAME: %[[ARG4:.+]]: !flow.dispatch.output - // CHECK: %[[LOAD:.+]] = flow.dispatch.input.load %[[ARG3]] + // CHECK-NEXT: %[[ARG3:.+]]: !flow.dispatch.tensor + // CHECK-SAME: %[[ARG4:.+]]: !flow.dispatch.tensor + // CHECK: %[[LOAD:.+]] = flow.dispatch.tensor.load %[[ARG3]] // CHECK: %[[STOREVAL:.+]] = linalg.generic // CHECK-SAME: outs(%[[LOAD]] : tensor) - // CHECK: flow.dispatch.output.store %[[STOREVAL]], %[[ARG4]] + // CHECK: flow.dispatch.tensor.store %[[STOREVAL]], %[[ARG4]] // linalg.generic is fused inside the dispatch region and becomes a noop but // there is still a use. diff --git a/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors_dynamic.mlir b/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors_dynamic.mlir index 8f69356d7826..e32e64940582 100644 --- a/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors_dynamic.mlir +++ b/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors_dynamic.mlir @@ -14,10 +14,10 @@ func @tensor(%arg0 : tensor, %arg1 : tensor, // CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: tensor // CHECK: flow.dispatch.workgroups // CHECK-SAME: (%[[ARG0]], %[[ARG1]], %[[ARG2]]) -// CHECK-NEXT: %[[ARG3:[a-zA-Z0-9_]+]]: !flow.dispatch.input -// CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]+]]: !flow.dispatch.input -// CHECK-SAME: %[[ARG5:[a-zA-Z0-9_]+]]: !flow.dispatch.input -// CHECK-SAME: %[[ARG6:[a-zA-Z0-9_]+]]: !flow.dispatch.output +// CHECK-NEXT: %[[ARG3:[a-zA-Z0-9_]+]]: !flow.dispatch.tensor +// CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]+]]: !flow.dispatch.tensor +// CHECK-SAME: %[[ARG5:[a-zA-Z0-9_]+]]: !flow.dispatch.tensor +// CHECK-SAME: %[[ARG6:[a-zA-Z0-9_]+]]: !flow.dispatch.tensor // CHECK-DAG: %[[C0:.+]] = constant 0 : index // CHECK-DAG: %[[WGSIZE_X:.+]] = flow.dispatch.workgroup.size[0] // CHECK-DAG: %[[WGSIZE_Y:.+]] = flow.dispatch.workgroup.size[1] @@ -33,16 +33,16 @@ func @tensor(%arg0 : tensor, %arg1 : tensor, // CHECK: %[[STEP_X:.+]] = affine.apply #[[MULMAP]]()[%[[WGCOUNT_X]], %[[WGSIZE_X]]] // CHECK: scf.for %[[ARG8:.+]] = %[[OFFSET_X]] // CHECK-SAME: to %{{.+}} step %[[STEP_X]] -// CHECK: %[[LHS:.+]] = flow.dispatch.input.load %[[ARG3]] +// CHECK: %[[LHS:.+]] = flow.dispatch.tensor.load %[[ARG3]] // CHECK-SAME: offsets = [%[[ARG7]], %[[C0]]] -// CHECK: %[[RHS:.+]] = flow.dispatch.input.load %[[ARG4]] +// CHECK: %[[RHS:.+]] = flow.dispatch.tensor.load %[[ARG4]] // CHECK-SAME: offsets = [%[[C0]], %[[ARG8]]] -// CHECK: %[[INIT:.+]] = flow.dispatch.input.load %[[ARG5]] +// CHECK: %[[INIT:.+]] = flow.dispatch.tensor.load %[[ARG5]] // CHECK-SAME: offsets = [%[[ARG7]], %[[ARG8]]] // CHECK: %[[RESULT:.+]] = linalg.matmul // CHECK-SAME: ins(%[[LHS]], %[[RHS]] : tensor, tensor) // CHECK-SAME: outs(%[[INIT]] : tensor) -// CHECK: flow.dispatch.output.store %[[RESULT]], %[[ARG6]] +// CHECK: flow.dispatch.tensor.store %[[RESULT]], %[[ARG6]] // CHECK-SAME: offsets = [%[[ARG7]], %[[ARG8]]] // ----- @@ -75,18 +75,18 @@ func @generic_op(%A: tensor, %B: tensor) -> tensor { // CHECK-DAG: %[[D1:.+]] = dim %[[ARG0]], %[[C1]] // CHECK: flow.dispatch.workgroups // CHECK-SAME: [%[[D1]], %[[D0]], %[[C1]]](%[[ARG0]], %[[ARG1]], %[[D0]], %[[D1]]) -// CHECK-NEXT: %[[ARG2:[a-zA-Z0-9_]+]]: !flow.dispatch.input -// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: !flow.dispatch.input +// CHECK-NEXT: %[[ARG2:[a-zA-Z0-9_]+]]: !flow.dispatch.tensor +// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: !flow.dispatch.tensor // CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]+]]: index // CHECK-SAME: %[[ARG5:[a-zA-Z0-9_]+]]: index -// CHECK-SAME: %[[ARG6:[a-zA-Z0-9_]+]]: !flow.dispatch.output -// CHECK-DAG: %[[LOAD2:.+]] = flow.dispatch.input.load %[[ARG2]] : !flow.dispatch.input +// CHECK-SAME: %[[ARG6:[a-zA-Z0-9_]+]]: !flow.dispatch.tensor +// CHECK-DAG: %[[LOAD2:.+]] = flow.dispatch.tensor.load %[[ARG2]] : !flow.dispatch.tensor // CHECK-DAG: %[[INIT:.+]] = linalg.init_tensor [%[[ARG4]], %[[ARG5]]] -// CHECK-DAG: %[[LOAD3:.+]] = flow.dispatch.input.load %[[ARG3]] : !flow.dispatch.input +// CHECK-DAG: %[[LOAD3:.+]] = flow.dispatch.tensor.load %[[ARG3]] : !flow.dispatch.tensor // CHECK: %[[RESULT:.+]] = linalg.generic // CHECK-SAME: ins(%[[LOAD2]], %[[LOAD3]] : tensor, tensor) // CHECK-SAME: outs(%[[INIT]] : tensor) -// CHECK: flow.dispatch.output.store %[[RESULT]], %[[ARG6]] +// CHECK: flow.dispatch.tensor.store %[[RESULT]], %[[ARG6]] // ----- @@ -113,20 +113,20 @@ func @fuse_fill_with_producer(%A : tensor, %B : tensor) -> ten // CHECK-SAME: (%[[M]], %[[N]], %[[ARG0]], %[[ARG1]]) // CHECK-NEXT: (%[[ARG2:[a-zA-Z0-9_]+]]: index // CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: index -// CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]+]]: !flow.dispatch.input -// CHECK-SAME: %[[ARG5:[a-zA-Z0-9_]+]]: !flow.dispatch.input -// CHECK-SAME: %[[ARG6:[a-zA-Z0-9_]+]]: !flow.dispatch.output) { +// CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]+]]: !flow.dispatch.tensor +// CHECK-SAME: %[[ARG5:[a-zA-Z0-9_]+]]: !flow.dispatch.tensor +// CHECK-SAME: %[[ARG6:[a-zA-Z0-9_]+]]: !flow.dispatch.tensor) { // CHECK: %[[ZERO:.+]] = constant 0.000000e+00 : f32 // CHECK: scf.for // CHECK: scf.for -// CHECK-DAG: %[[LHS_TILE:.+]] = flow.dispatch.input.load %[[ARG4]] -// CHECK-DAG: %[[RHS_TILE:.+]] = flow.dispatch.input.load %[[ARG5]] +// CHECK-DAG: %[[LHS_TILE:.+]] = flow.dispatch.tensor.load %[[ARG4]] +// CHECK-DAG: %[[RHS_TILE:.+]] = flow.dispatch.tensor.load %[[ARG5]] // CHECK-DAG: %[[INIT_TILE:.+]] = linalg.init_tensor // CHECK: %[[FILL_TILE:.+]] = linalg.fill(%[[INIT_TILE]], %[[ZERO]]) // CHECK: %[[RESULT_TILE:.+]] = linalg.matmul // CHECK-SAME: ins(%[[LHS_TILE]], %[[RHS_TILE]] : tensor, tensor) // CHECK-SAME: outs(%[[FILL_TILE]] : tensor) -// CHECK: flow.dispatch.output.store %[[RESULT_TILE]], %[[ARG6]] +// CHECK: flow.dispatch.tensor.store %[[RESULT_TILE]], %[[ARG6]] // CHECK: flow.return // CHECK: } @@ -166,17 +166,17 @@ func @two_dispatches(%A : tensor, %B : tensor) -> tensor +// CHECK-NEXT: (%[[ARG2:[a-zA-Z0-9_]+]]: !flow.dispatch.tensor // CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: index // CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]+]]: index -// CHECK-SAME: %[[ARG5:[a-zA-Z0-9_]+]]: !flow.dispatch.output) { +// CHECK-SAME: %[[ARG5:[a-zA-Z0-9_]+]]: !flow.dispatch.tensor) { // CHECK: %[[ONE:.+]] = constant 1.0 -// CHECK-DAG: %[[INPUT:.+]] = flow.dispatch.input.load %[[ARG2]] +// CHECK-DAG: %[[INPUT:.+]] = flow.dispatch.tensor.load %[[ARG2]] // CHECK-DAG: %[[INIT:.+]] = linalg.init_tensor // CHECK: %[[RESULT:.+]] = linalg.generic // CHECK-SAME: ins(%[[INPUT]] : tensor) // CHECK-SAME: outs(%[[INIT]] : tensor) -// CHECK: flow.dispatch.output.store %[[RESULT]], %[[ARG5]] +// CHECK: flow.dispatch.tensor.store %[[RESULT]], %[[ARG5]] // CHECK: flow.return // CHECK: } // CHECK: flow.dispatch.workgroups[%[[N]], %[[M]], %[[C1]]] @@ -192,21 +192,21 @@ func @two_dispatches(%A : tensor, %B : tensor) -> tensor -// NOCHECK-SAME: %[[ARG5:[a-zA-Z0-9_]+]]: !flow.dispatch.input -// NOCHECK-SAME: %[[ARG6:[a-zA-Z0-9_]+]]: !flow.dispatch.input -// NOCHECK-SAME: %[[ARG7:[a-zA-Z0-9_]+]]: !flow.dispatch.output) { +// NOCHECK-SAME: %[[ARG4:[a-zA-Z0-9_]+]]: !flow.dispatch.tensor +// NOCHECK-SAME: %[[ARG5:[a-zA-Z0-9_]+]]: !flow.dispatch.tensor +// NOCHECK-SAME: %[[ARG6:[a-zA-Z0-9_]+]]: !flow.dispatch.tensor +// NOCHECK-SAME: %[[ARG7:[a-zA-Z0-9_]+]]: !flow.dispatch.tensor) { // NOCHECK: %[[ZERO:.+]] = constant 0.0 // NOCHECK: scf.for // NOCHECK: scf.for -// NOCHECK-DAG: %[[LHS_TILE_2:.+]] = flow.dispatch.input.load %[[ARG6]] -// NOCHECK-DAG: %[[RHS_TILE_2:.+]] = flow.dispatch.input.load %[[ARG5]] +// NOCHECK-DAG: %[[LHS_TILE_2:.+]] = flow.dispatch.tensor.load %[[ARG6]] +// NOCHECK-DAG: %[[RHS_TILE_2:.+]] = flow.dispatch.tensor.load %[[ARG5]] // NOCHECK-DAG: %[[INIT_TILE_2:.+]] = linalg.init_tensor // NOCHECK: %[[FILL_TILE:.+]] = linalg.fill(%[[INIT_TILE]], %[[ZERO]]) // NOCHECK: %[[RESULT_TILE_2:.++]]] = linalg.matmul // NOCHECK-SAME: ins(%[[LHS_TILE_2]], %[[RHS_TILE_2]] : tensor, tensor) // NOCHECK: outs(%[[FILL_TILE_2]] : tensor) -// NOCHECK: flow.dispatch.output.store %[[RESULT_TILE_2]], %[[ARG7]] +// NOCHECK: flow.dispatch.tensor.store %[[RESULT_TILE_2]], %[[ARG7]] // NOCHECK: flow.return // NOCHECK: } @@ -228,22 +228,22 @@ func @dot_general_lower() attributes {iree.module.export} { } // CHECK-LABEL: func @dot_general_lower // CHECK: flow.dispatch.workgroups[%{{.+}}, %{{.+}}, %{{.+}}] -// CHECK-NEXT: %[[ARG0:[a-zA-Z0-9_]+]]: !flow.dispatch.input<1x1x2xf32> -// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: !flow.dispatch.input<2x3xf32> -// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: !flow.dispatch.output<1x3xf32> +// CHECK-NEXT: %[[ARG0:[a-zA-Z0-9_]+]]: !flow.dispatch.tensor +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: !flow.dispatch.tensor +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: !flow.dispatch.tensor // CHECK-DAG: %[[ZERO:.+]] = constant 0.0 -// CHECK: %[[LOAD:.+]] = flow.dispatch.input.load %[[ARG0]] +// CHECK: %[[LOAD:.+]] = flow.dispatch.tensor.load %[[ARG0]] // CHECK: %[[RESHAPE:.+]] = linalg.tensor_reshape %[[LOAD]] // CHECK: scf.for // CHECK: scf.for // CHECK-DAG: %[[LHS:.+]] = subtensor %[[RESHAPE]] -// CHECK-DAG: %[[RHS:.+]] = flow.dispatch.input.load %[[ARG1]] +// CHECK-DAG: %[[RHS:.+]] = flow.dispatch.tensor.load %[[ARG1]] // CHECK: %[[INIT:.+]] = linalg.init // CHECK: %[[FILL:.+]] = linalg.fill(%[[INIT]], %[[ZERO]]) // CHECK: %[[RESULT:.+]] = linalg.matmul // CHECK-SAME: ins(%[[LHS]], %[[RHS]] : tensor, tensor<2x?xf32>) // CHECK-SAME: outs(%[[FILL]] : tensor) -// CHECK: flow.dispatch.output.store %[[RESULT]], %[[ARG2]] +// CHECK: flow.dispatch.tensor.store %[[RESULT]], %[[ARG2]] // ----- @@ -263,11 +263,11 @@ func @reshapeop(%arg0: tensor) -> tensor // CHECK: %[[WORKLOAD:.+]] = affine.apply #[[MAP0]]()[%[[D0]], %[[D1]]] // CHECK: %[[RESULT:.+]] = flow.dispatch.workgroups // CHECK-SAME: [%[[WORKLOAD]], %[[C1]], %[[C1]]](%[[ARG0]]) -// CHECK-NEXT: %[[ARG1:.+]]: !flow.dispatch.input -// CHECK-SAME: %[[ARG2:.+]]: !flow.dispatch.output -// CHECK: %[[LOAD:.+]] = flow.dispatch.input.load %[[ARG1]] +// CHECK-NEXT: %[[ARG1:.+]]: !flow.dispatch.tensor +// CHECK-SAME: %[[ARG2:.+]]: !flow.dispatch.tensor +// CHECK: %[[LOAD:.+]] = flow.dispatch.tensor.load %[[ARG1]] // CHECK: %[[RESHAPE:.+]] = linalg.tensor_reshape %[[LOAD]] [#[[MAP1]]] -// CHECK: flow.dispatch.output.store %[[RESHAPE]], %[[ARG2]] +// CHECK: flow.dispatch.tensor.store %[[RESHAPE]], %[[ARG2]] // ----- @@ -355,8 +355,8 @@ func @always_fuse_reshape // ----- -func @pad_test(%arg0 : tensor, %arg1 : tensor, %arg2 : index, - %arg3 : index, %arg4 : index, %arg5 : index ) -> tensor { +func @pad_test(%arg0: tensor, %arg1: tensor, %arg2: index, + %arg3: index, %arg4: index, %arg5: index) -> tensor { %c0 = constant 0 : index %c1 = constant 1 : index %0 = tensor.extract %arg1[] : tensor @@ -390,6 +390,6 @@ func @pad_test(%arg0 : tensor, %arg1 : tensor, %arg2 : index, // CHECK-DAG: %[[VAL:.+]] = tensor.extract // CHECK-DAG: %[[INIT:.+]] = linalg.init_tensor // CHECK: %[[RETURN:.+]] = linalg.fill(%[[INIT]], %[[VAL]]) -// CHECK: flow.dispatch.output.store %[[RETURN]] +// CHECK: flow.dispatch.tensor.store %[[RETURN]] // CHECK-NEXT: flow.return // CHECK: flow.tensor.update %[[ARG0]], %[[RESULT]] diff --git a/iree/compiler/Dialect/Flow/Transforms/test/form_streams.mlir b/iree/compiler/Dialect/Flow/Transforms/test/form_streams.mlir index 927f1b5b9363..630dd616ed6c 100644 --- a/iree/compiler/Dialect/Flow/Transforms/test/form_streams.mlir +++ b/iree/compiler/Dialect/Flow/Transforms/test/form_streams.mlir @@ -5,13 +5,13 @@ func @outsideTieShape(%arg0: tensor {iree.reflection = {}}, %arg1: !shape %c0 = constant 0 : index // CHECK-DAG: %[[DIM:.+]] = shapex.ranked_dim %arg1[0] %dim = shapex.ranked_dim %arg1[0] : !shapex.ranked_shape<[?]> -> index - // CHECK-NEXT: %[[RET:.+]] = flow.ex.stream.fragment(%[[DIM]], %arg0) : (index, tensor{%[[DIM]]}) -> (tensor{%[[DIM]]}) = - // CHECK-NEXT: (%[[INNER_DIM:.+]]: index, %[[CAPTURE:.+]]: tensor) -> (tensor) { + // CHECK-NEXT: %[[RET:.+]] = flow.ex.stream.fragment(%[[DIM]], %arg0) : (index, tensor{%[[DIM]]}) -> tensor{%[[DIM]]} = + // CHECK-NEXT: (%[[INNER_DIM:.+]]: index, %[[CAPTURE:.+]]: tensor) -> tensor { // CHECK-NEXT: %[[WORKLOAD0:.+]] = constant 0 : index - // CHECK-NEXT: %[[INNER_RET:.+]] = flow.dispatch @main_ex_dispatch_1::@main_ex_dispatch_1[%[[WORKLOAD0]]](%[[INNER_DIM]], %[[CAPTURE]]) : (index, tensor{%[[INNER_DIM]]}) -> (tensor{%[[INNER_DIM]]}) + // CHECK-NEXT: %[[INNER_RET:.+]] = flow.dispatch @main_ex_dispatch_1::@main_ex_dispatch_1[%[[WORKLOAD0]]](%[[INNER_DIM]], %[[CAPTURE]]) : (index, tensor{%[[INNER_DIM]]}) -> tensor{%[[INNER_DIM]]} // CHECK-NEXT: flow.return %[[INNER_RET]] : tensor // CHECK-NEXT: } - %15 = flow.dispatch @main_ex_dispatch_1::@main_ex_dispatch_1[%c0](%dim, %arg0) : (index, tensor{%dim}) -> tensor{%dim} + %15 = flow.dispatch @main_ex_dispatch_1::@main_ex_dispatch_1[%c0](%dim, %arg0) : (index, tensor{%dim}) -> (tensor{%dim}) // CHECK-NEXT: return %[[RET]] : tensor return %15 : tensor } @@ -34,10 +34,10 @@ func @outerOps(%arg0: tensor<4xf32>) -> tensor<4xf32> { // CHECK: %0 = addf %arg0, %arg0 : tensor<4xf32> %0 = addf %arg0, %arg0 : tensor<4xf32> %cst = constant 4 : index - // CHECK-NEXT: %1 = flow.ex.stream.fragment(%0) : (tensor<4xf32>) -> (tensor<4xf32>) = - // CHECK-NEXT: (%[[INNER_ARG:.+]]: tensor<4xf32>) -> (tensor<4xf32>) { + // CHECK-NEXT: %1 = flow.ex.stream.fragment(%0) : (tensor<4xf32>) -> tensor<4xf32> = + // CHECK-NEXT: (%[[INNER_ARG:.+]]: tensor<4xf32>) -> tensor<4xf32> { // CHECK-NEXT: %[[WORKLOAD:.+]] = constant 4 : index - // CHECK-NEXT: %3 = flow.dispatch @outerOps_ex_dispatch_0::@outerOps_rgn_dispatch_0[%[[WORKLOAD]]](%[[INNER_ARG]]) : (tensor<4xf32>) -> (tensor<4xf32>) + // CHECK-NEXT: %3 = flow.dispatch @outerOps_ex_dispatch_0::@outerOps_rgn_dispatch_0[%[[WORKLOAD]]](%[[INNER_ARG]]) : (tensor<4xf32>) -> tensor<4xf32> // CHECK-NEXT: flow.return %3 : tensor<4xf32> // CHECK-NEXT: } %1 = flow.dispatch @outerOps_ex_dispatch_0::@outerOps_rgn_dispatch_0[%cst](%0) : (tensor<4xf32>) -> tensor<4xf32> @@ -54,12 +54,12 @@ func @nondependentOuterOps(%arg0: tensor<4xf32>) -> tensor<4xf32> { %cst = constant 4 : index // CHECK-NEXT: %[[ADD1:.+]] = addf %arg0, %arg0 : tensor<4xf32> %add1 = addf %arg0, %arg0 : tensor<4xf32> - // CHECK-NEXT: %[[S:.+]] = flow.ex.stream.fragment(%arg0, %[[ADD1]]) : (tensor<4xf32>, tensor<4xf32>) -> (tensor<4xf32>) = - // CHECK-NEXT: (%arg1: tensor<4xf32>, %arg2: tensor<4xf32>) -> (tensor<4xf32>) { + // CHECK-NEXT: %[[S:.+]] = flow.ex.stream.fragment(%arg0, %[[ADD1]]) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> = + // CHECK-NEXT: (%arg1: tensor<4xf32>, %arg2: tensor<4xf32>) -> tensor<4xf32> { // CHECK-NEXT: %[[WORKLOAD:.+]] = constant 4 : index - // CHECK-NEXT: %[[D1:.+]] = flow.dispatch @dispatch_1::@dispatch_1[%[[WORKLOAD]]](%arg1, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> (tensor<4xf32>) + // CHECK-NEXT: %[[D1:.+]] = flow.dispatch @dispatch_1::@dispatch_1[%[[WORKLOAD]]](%arg1, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> %d1 = flow.dispatch @dispatch_1::@dispatch_1[%cst](%arg0, %arg0) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> - // CHECK-NEXT: %[[D2:.+]] = flow.dispatch @dispatch_2::@dispatch_2[%[[WORKLOAD]]](%[[D1]], %arg2) : (tensor<4xf32>, tensor<4xf32>) -> (tensor<4xf32>) + // CHECK-NEXT: %[[D2:.+]] = flow.dispatch @dispatch_2::@dispatch_2[%[[WORKLOAD]]](%[[D1]], %arg2) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> %d2 = flow.dispatch @dispatch_2::@dispatch_2[%cst](%d1, %add1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> // CHECK-NEXT: flow.return %[[D2]] : tensor<4xf32> // CHECK-NEXT: } @@ -85,19 +85,19 @@ flow.executable @interleavedOuterOps_ex_dispatch_0 { // CHECK-LABEL: func @interleavedOuterOps( func @interleavedOuterOps(%arg0: tensor<4xf32>) -> tensor<4xf32> { %cst = constant 4 : index - // CHECK-NEXT: %0 = flow.ex.stream.fragment(%arg0) : (tensor<4xf32>) -> (tensor<4xf32>) = - // CHECK-NEXT: (%arg1: tensor<4xf32>) -> (tensor<4xf32>) + // CHECK-NEXT: %0 = flow.ex.stream.fragment(%arg0) : (tensor<4xf32>) -> tensor<4xf32> = + // CHECK-NEXT: (%arg1: tensor<4xf32>) -> tensor<4xf32> // CHECK-NEXT: %[[WORKLOAD1:.+]] = constant 4 : index - // CHECK-NEXT: %3 = flow.dispatch @interleavedOuterOps_ex_dispatch_0::@interleavedOuterOps_rgn_dispatch_0[%[[WORKLOAD1]]](%arg1) : (tensor<4xf32>) -> (tensor<4xf32>) + // CHECK-NEXT: %3 = flow.dispatch @interleavedOuterOps_ex_dispatch_0::@interleavedOuterOps_rgn_dispatch_0[%[[WORKLOAD1]]](%arg1) : (tensor<4xf32>) -> tensor<4xf32> // CHECK-NEXT: flow.return %3 : tensor<4xf32> // CHECK-NEXT: } %0 = flow.dispatch @interleavedOuterOps_ex_dispatch_0::@interleavedOuterOps_rgn_dispatch_0[%cst](%arg0) : (tensor<4xf32>) -> tensor<4xf32> // CHECK-NEXT: %1 = addf %0, %0 : tensor<4xf32> %1 = addf %0, %0 : tensor<4xf32> - // CHECK-NEXT: %2 = flow.ex.stream.fragment(%1) : (tensor<4xf32>) -> (tensor<4xf32>) = - // CHECK-NEXT: (%arg1: tensor<4xf32>) -> (tensor<4xf32>) { + // CHECK-NEXT: %2 = flow.ex.stream.fragment(%1) : (tensor<4xf32>) -> tensor<4xf32> = + // CHECK-NEXT: (%arg1: tensor<4xf32>) -> tensor<4xf32> { // CHECK-NEXT: %[[WORKLOAD2:.+]] = constant 4 : index - // CHECK-NEXT: %3 = flow.dispatch @interleavedOuterOps_ex_dispatch_0::@interleavedOuterOps_rgn_dispatch_0[%[[WORKLOAD2]]](%arg1) : (tensor<4xf32>) -> (tensor<4xf32>) + // CHECK-NEXT: %3 = flow.dispatch @interleavedOuterOps_ex_dispatch_0::@interleavedOuterOps_rgn_dispatch_0[%[[WORKLOAD2]]](%arg1) : (tensor<4xf32>) -> tensor<4xf32> // CHECK-NEXT: flow.return %3 : tensor<4xf32> // CHECK-NEXT: } %2 = flow.dispatch @interleavedOuterOps_ex_dispatch_0::@interleavedOuterOps_rgn_dispatch_0[%cst](%1) : (tensor<4xf32>) -> tensor<4xf32> @@ -169,12 +169,12 @@ flow.executable @interleavedDot_ex_dispatch_2 { // CHECK-LABEL: func @interleavedDot( func @interleavedDot(%arg0: tensor<4x4xf32>) -> tensor<4x4xf32> { %cst = constant 16 : index - // CHECK-NEXT: %0 = flow.ex.stream.fragment(%arg0) : (tensor<4x4xf32>) -> (tensor<4x4xf32>) = - // CHECK-NEXT: (%arg1: tensor<4x4xf32>) -> (tensor<4x4xf32>) { + // CHECK-NEXT: %0 = flow.ex.stream.fragment(%arg0) : (tensor<4x4xf32>) -> tensor<4x4xf32> = + // CHECK-NEXT: (%arg1: tensor<4x4xf32>) -> tensor<4x4xf32> { // CHECK-NEXT: %[[WORKLOAD:.+]] = constant 16 : index - // CHECK-NEXT: %1 = flow.dispatch @interleavedDot_ex_dispatch_0::@interleavedDot_rgn_dispatch_0[%[[WORKLOAD]]](%arg1) : (tensor<4x4xf32>) -> (tensor<4x4xf32>) - // CHECK-NEXT: %2 = flow.dispatch @interleavedDot_ex_dispatch_1::@interleavedDot_rgn_dispatch_1[%[[WORKLOAD]]](%1, %arg1) : (tensor<4x4xf32>, tensor<4x4xf32>) -> (tensor<4x4xf32>) - // CHECK-NEXT: %3 = flow.dispatch @interleavedDot_ex_dispatch_2::@interleavedDot_rgn_dispatch_2[%[[WORKLOAD]]](%2, %arg1) : (tensor<4x4xf32>, tensor<4x4xf32>) -> (tensor<4x4xf32>) + // CHECK-NEXT: %1 = flow.dispatch @interleavedDot_ex_dispatch_0::@interleavedDot_rgn_dispatch_0[%[[WORKLOAD]]](%arg1) : (tensor<4x4xf32>) -> tensor<4x4xf32> + // CHECK-NEXT: %2 = flow.dispatch @interleavedDot_ex_dispatch_1::@interleavedDot_rgn_dispatch_1[%[[WORKLOAD]]](%1, %arg1) : (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> + // CHECK-NEXT: %3 = flow.dispatch @interleavedDot_ex_dispatch_2::@interleavedDot_rgn_dispatch_2[%[[WORKLOAD]]](%2, %arg1) : (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> // CHECK-NEXT: flow.return %3 : tensor<4x4xf32> // CHECK-NEXT: } %0 = flow.dispatch @interleavedDot_ex_dispatch_0::@interleavedDot_rgn_dispatch_0[%cst](%arg0) : (tensor<4x4xf32>) -> tensor<4x4xf32> @@ -211,19 +211,19 @@ flow.executable @caller_ex_dispatch_1 { // CHECK-LABEL: func @caller( func @caller(%arg0: tensor<4xf32>) -> tensor<4xf32> { %cst = constant 4 : index - // CHECK-NEXT: %0 = flow.ex.stream.fragment(%arg0) : (tensor<4xf32>) -> (tensor<4xf32>) = - // CHECK-NEXT: (%arg1: tensor<4xf32>) -> (tensor<4xf32>) { + // CHECK-NEXT: %0 = flow.ex.stream.fragment(%arg0) : (tensor<4xf32>) -> tensor<4xf32> = + // CHECK-NEXT: (%arg1: tensor<4xf32>) -> tensor<4xf32> { // CHECK-NEXT: %[[WORKLOAD1:.+]] = constant 4 : index - // CHECK-NEXT: %3 = flow.dispatch @caller_ex_dispatch_0::@caller_rgn_dispatch_0[%[[WORKLOAD1]]](%arg1) : (tensor<4xf32>) -> (tensor<4xf32>) + // CHECK-NEXT: %3 = flow.dispatch @caller_ex_dispatch_0::@caller_rgn_dispatch_0[%[[WORKLOAD1]]](%arg1) : (tensor<4xf32>) -> tensor<4xf32> // CHECK-NEXT: flow.return %3 : tensor<4xf32> // CHECK-NEXT: } %0 = flow.dispatch @caller_ex_dispatch_0::@caller_rgn_dispatch_0[%cst](%arg0) : (tensor<4xf32>) -> tensor<4xf32> // CHECK-NEXT: %1 = call @callee(%0) : (tensor<4xf32>) -> tensor<4xf32> %1 = call @callee(%0) : (tensor<4xf32>) -> tensor<4xf32> - // CHECK-NEXT: %2 = flow.ex.stream.fragment(%arg0, %1) : (tensor<4xf32>, tensor<4xf32>) -> (tensor<4xf32>) = - // CHECK-NEXT: (%arg1: tensor<4xf32>, %arg2: tensor<4xf32>) -> (tensor<4xf32>) { + // CHECK-NEXT: %2 = flow.ex.stream.fragment(%arg0, %1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> = + // CHECK-NEXT: (%arg1: tensor<4xf32>, %arg2: tensor<4xf32>) -> tensor<4xf32> { // CHECK-NEXT: %[[WORKLOAD2:.+]] = constant 4 : index - // CHECK-NEXT: %3 = flow.dispatch @caller_ex_dispatch_1::@caller_rgn_dispatch_1[%[[WORKLOAD2]]](%arg1, %arg2) : (tensor<4xf32>, tensor<4xf32>) -> (tensor<4xf32>) + // CHECK-NEXT: %3 = flow.dispatch @caller_ex_dispatch_1::@caller_rgn_dispatch_1[%[[WORKLOAD2]]](%arg1, %arg2) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> // CHECK-NEXT: flow.return %3 : tensor<4xf32> // CHECK-NEXT: } %2 = flow.dispatch @caller_ex_dispatch_1::@caller_rgn_dispatch_1[%cst](%arg0, %1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> @@ -242,10 +242,10 @@ flow.executable @callee_ex_dispatch_0 { // CHECK-LABEL: func @callee( func @callee(%arg0: tensor<4xf32>) -> tensor<4xf32> { %cst = constant 4 : index - // CHECK-NEXT: %0 = flow.ex.stream.fragment(%arg0) : (tensor<4xf32>) -> (tensor<4xf32>) = - // CHECK-NEXT: (%arg1: tensor<4xf32>) -> (tensor<4xf32>) { + // CHECK-NEXT: %0 = flow.ex.stream.fragment(%arg0) : (tensor<4xf32>) -> tensor<4xf32> = + // CHECK-NEXT: (%arg1: tensor<4xf32>) -> tensor<4xf32> { // CHECK-NEXT: %[[WORKLOAD:.+]] = constant 4 : index - // CHECK-NEXT: %1 = flow.dispatch @callee_ex_dispatch_0::@callee_rgn_dispatch_0[%[[WORKLOAD]]](%arg1) : (tensor<4xf32>) -> (tensor<4xf32>) + // CHECK-NEXT: %1 = flow.dispatch @callee_ex_dispatch_0::@callee_rgn_dispatch_0[%[[WORKLOAD]]](%arg1) : (tensor<4xf32>) -> tensor<4xf32> // CHECK-NEXT: flow.return %1 : tensor<4xf32> // CHECK-NEXT: } %0 = flow.dispatch @callee_ex_dispatch_0::@callee_rgn_dispatch_0[%cst](%arg0) : (tensor<4xf32>) -> tensor<4xf32> @@ -253,6 +253,8 @@ func @callee(%arg0: tensor<4xf32>) -> tensor<4xf32> { return %0 : tensor<4xf32> } +// ----- + // CHECK-LABEL: @simple_unary // CHECK-SAME: %[[A0:.+]]: tensor // CHECK-SAME: %[[A1:.+]]: !shapex.ranked_shape<[?,?]> @@ -265,16 +267,15 @@ func @simple_unary(%arg0: tensor, %arg1: !shapex.ranked_shape<[?,?]>) - %2 = muli %dim0, %dim1 : index // Verify that the fragment captures the tie_shapes and marshals the indices // in as loose index values (not as ranked_shape types). - // CHECK: %[[S:.+]] = flow.ex.stream.fragment(%[[SZ]], %[[A0]], %[[DIM0]], %[[DIM1]]) : (index, tensor{%[[DIM0]], %[[DIM1]]}, index, index) -> (tensor{%[[DIM0]], %[[DIM1]]}) = - // CHECK: (%arg2: index, %arg3: tensor, %arg4: index, %arg5: index) -> (tensor) { - // CHECK: %[[STREAM_RET:.+]] = flow.dispatch @simple_unary_ex_dispatch_0{{.+}}[%arg2](%arg3, %arg4, %arg5) : (tensor{%arg4, %arg5}, index, index) -> (tensor{%arg4, %arg5}) + // CHECK: %[[S:.+]] = flow.ex.stream.fragment(%[[SZ]], %[[A0]], %[[DIM0]], %[[DIM1]]) : (index, tensor{%[[DIM0]], %[[DIM1]]}, index, index) -> tensor{%[[DIM0]], %[[DIM1]]} = + // CHECK: (%arg2: index, %arg3: tensor, %arg4: index, %arg5: index) -> tensor { + // CHECK: %[[STREAM_RET:.+]] = flow.dispatch @simple_unary_ex_dispatch_0{{.+}}[%arg2](%arg3, %arg4, %arg5) : (tensor{%arg4, %arg5}, index, index) -> tensor{%arg4, %arg5} // CHECK: return %[[STREAM_RET]] // CHECK: } %3 = flow.dispatch @simple_unary_ex_dispatch_0::@simple_unary_ex_dispatch_0[%2](%arg0, %dim0, %dim1) : (tensor{%dim0, %dim1}, index, index) -> tensor{%dim0, %dim1} return %3, %arg1 : tensor, !shapex.ranked_shape<[?,?]> } - // ----- // CHECK-LABEL: @bad_input_ordering @@ -327,7 +328,7 @@ func @ordering(%w : index) -> (tensor, tensor, tensor) { // CHECK: %[[C1:.+]] = constant 1 // CHECK-DAG: %[[D1:.+]] = flow.dispatch @dispatch_1::@dispatch_1 // CHECK-NEXT: flow.return %[[D1]] - %d1 = flow.dispatch @dispatch_1::@dispatch_1[%w](%c1) : (i32) -> (tensor) + %d1 = flow.dispatch @dispatch_1::@dispatch_1[%w](%c1) : (i32) -> tensor // CHECK: %[[SE_USER:.+]] = iree.do_not_optimize(%[[S1]]) %side_effecting_user = iree.do_not_optimize(%d1) : tensor %c2 = constant 2 : i32 @@ -335,7 +336,7 @@ func @ordering(%w : index) -> (tensor, tensor, tensor) { // CHECK: %[[C2:.+]] = constant 2 // CHECK-DAG: %[[D2:.+]] = flow.dispatch @dispatch_2::@dispatch_2 // CHECK-NEXT: flow.return %[[D2]] - %d2 = flow.dispatch @dispatch_2::@dispatch_2[%w](%c2) : (i32) -> (tensor) + %d2 = flow.dispatch @dispatch_2::@dispatch_2[%w](%c2) : (i32) -> tensor // CHECK: return %[[S1]], %[[S2]], %[[SE_USER]] return %d1, %d2, %side_effecting_user : tensor, tensor, tensor } diff --git a/iree/compiler/Dialect/Flow/Transforms/test/hoist_unstreamable_ops.mlir b/iree/compiler/Dialect/Flow/Transforms/test/hoist_unstreamable_ops.mlir index d94b95aa290e..811e3e44d672 100644 --- a/iree/compiler/Dialect/Flow/Transforms/test/hoist_unstreamable_ops.mlir +++ b/iree/compiler/Dialect/Flow/Transforms/test/hoist_unstreamable_ops.mlir @@ -8,12 +8,12 @@ func @constants() { // CHECK-DAG: constant 4 : index // CHECK-DAG: constant 5 : index // CHECK-DAG: constant 6 : index - // CHECK: flow.dispatch @dispatch0::@dispatch0[%[[W]]]() : () -> (tensor) - // CHECK: flow.dispatch @dispatch1::@dispatch1[%[[W]]]() : () -> (tensor) - // CHECK: flow.dispatch @dispatch2::@dispatch2[%[[W]]]() : () -> (tensor) - // CHECK: flow.dispatch @dispatch3::@dispatch3[%[[W]]]() : () -> (tensor) - // CHECK: flow.dispatch @dispatch4::@dispatch4[%[[W]]]() : () -> (tensor) - // CHECK: flow.dispatch @dispatch5::@dispatch5[%[[W]]]() : () -> (tensor) + // CHECK: flow.dispatch @dispatch0::@dispatch0[%[[W]]]() : () -> tensor + // CHECK: flow.dispatch @dispatch1::@dispatch1[%[[W]]]() : () -> tensor + // CHECK: flow.dispatch @dispatch2::@dispatch2[%[[W]]]() : () -> tensor + // CHECK: flow.dispatch @dispatch3::@dispatch3[%[[W]]]() : () -> tensor + // CHECK: flow.dispatch @dispatch4::@dispatch4[%[[W]]]() : () -> tensor + // CHECK: flow.dispatch @dispatch5::@dispatch5[%[[W]]]() : () -> tensor %w = constant 1 : index %d0 = flow.dispatch @dispatch0::@dispatch0[%w]() : () -> tensor %c2 = constant 2 : index diff --git a/iree/compiler/Dialect/Flow/Transforms/test/inject_dispatch_tracing.mlir b/iree/compiler/Dialect/Flow/Transforms/test/inject_dispatch_tracing.mlir index d99d312c71c7..855df02f3182 100644 --- a/iree/compiler/Dialect/Flow/Transforms/test/inject_dispatch_tracing.mlir +++ b/iree/compiler/Dialect/Flow/Transforms/test/inject_dispatch_tracing.mlir @@ -5,7 +5,7 @@ func @singleDispatch(%arg0: tensor<4xf32>) -> tensor<4xf32> { %c4 = constant 4 : index // CHECK: flow.tensor.trace {key = "ex::entry0 inputs"} %[[ARG0]] : tensor<4xf32> - // CHECK-NEXT: %[[RET0:.+]] = flow.dispatch @ex::@entry0[%c4](%[[ARG0]]) : (tensor<4xf32>) -> (tensor<4xf32>) + // CHECK-NEXT: %[[RET0:.+]] = flow.dispatch @ex::@entry0[%c4](%[[ARG0]]) : (tensor<4xf32>) -> tensor<4xf32> %0 = flow.dispatch @ex::@entry0[%c4](%arg0) : (tensor<4xf32>) -> tensor<4xf32> // CHECK-NEXT: flow.tensor.trace {key = "ex::entry0 outputs"} %[[RET0]] : tensor<4xf32> // CHECK-NEXT: return %[[RET0]] @@ -20,12 +20,12 @@ func @multiDispatch(%arg0: tensor<4xf32>) -> tensor<4xf32> { %c4 = constant 4 : index // CHECK: flow.tensor.trace {key = "ex::entry0 inputs"} %[[ARG0]] : tensor<4xf32> - // CHECK-NEXT: %[[RET0:.+]] = flow.dispatch @ex::@entry0[%c4](%[[ARG0]]) : (tensor<4xf32>) -> (tensor<4xf32>) + // CHECK-NEXT: %[[RET0:.+]] = flow.dispatch @ex::@entry0[%c4](%[[ARG0]]) : (tensor<4xf32>) -> tensor<4xf32> %0 = flow.dispatch @ex::@entry0[%c4](%arg0) : (tensor<4xf32>) -> tensor<4xf32> // CHECK-NEXT: flow.tensor.trace {key = "ex::entry0 outputs"} %[[RET0]] : tensor<4xf32> // CHECK: flow.tensor.trace {key = "ex::entry1 inputs"} %[[RET0]] : tensor<4xf32> - // CHECK-NEXT: %[[RET1:.+]] = flow.dispatch @ex::@entry1[%c4](%[[RET0]]) : (tensor<4xf32>) -> (tensor<4xf32>) + // CHECK-NEXT: %[[RET1:.+]] = flow.dispatch @ex::@entry1[%c4](%[[RET0]]) : (tensor<4xf32>) -> tensor<4xf32> %1 = flow.dispatch @ex::@entry1[%c4](%0) : (tensor<4xf32>) -> tensor<4xf32> // CHECK-NEXT: flow.tensor.trace {key = "ex::entry1 outputs"} %[[RET1]] : tensor<4xf32> diff --git a/iree/compiler/Dialect/Flow/Transforms/test/outline_dispatch_regions2.mlir b/iree/compiler/Dialect/Flow/Transforms/test/outline_dispatch_regions2.mlir index 2719aa3f8882..81e0e639d0c2 100644 --- a/iree/compiler/Dialect/Flow/Transforms/test/outline_dispatch_regions2.mlir +++ b/iree/compiler/Dialect/Flow/Transforms/test/outline_dispatch_regions2.mlir @@ -5,13 +5,13 @@ // CHECK-SAME: signature = (tensor<8x4xf32>) -> tensor<4x8xf32>, // CHECK-SAME: workgroup_rank = 2 : index} // CHECK: func @staticShapeDispatch_dispatch_0( -// CHECK-SAME: %[[ARG:.+]]: !flow.dispatch.input<8x4xf32>, -// CHECK-SAME: %[[RET:.+]]: !flow.dispatch.output<4x8xf32>) { -// CHECK-DAG: %[[ARG_VALUE:.+]] = flow.dispatch.input.load %[[ARG]] : !flow.dispatch.input<8x4xf32> -> tensor<8x4xf32> -// CHECK-DAG: %[[ARG_SHAPE:.+]] = flow.dispatch.shape %[[ARG]] : !flow.dispatch.input<8x4xf32> -> !shapex.ranked_shape<[8,4]> -// CHECK-DAG: %[[RET_SHAPE:.+]] = flow.dispatch.shape %[[RET]] : !flow.dispatch.output<4x8xf32> -> !shapex.ranked_shape<[4,8]> +// CHECK-SAME: %[[ARG:.+]]: !flow.dispatch.tensor, +// CHECK-SAME: %[[RET:.+]]: !flow.dispatch.tensor) { +// CHECK-DAG: %[[ARG_VALUE:.+]] = flow.dispatch.tensor.load %[[ARG]] : !flow.dispatch.tensor -> tensor<8x4xf32> +// CHECK-DAG: %[[ARG_SHAPE:.+]] = flow.dispatch.shape %[[ARG]] : !flow.dispatch.tensor -> !shapex.ranked_shape<[8,4]> +// CHECK-DAG: %[[RET_SHAPE:.+]] = flow.dispatch.shape %[[RET]] : !flow.dispatch.tensor -> !shapex.ranked_shape<[4,8]> // CHECK-NEXT: %[[RET_VALUE:.+]] = "test.sink"(%[[ARG_VALUE]], %[[ARG_SHAPE]], %[[RET_SHAPE]]) : (tensor<8x4xf32>, !shapex.ranked_shape<[8,4]>, !shapex.ranked_shape<[4,8]>) -> tensor<4x8xf32> -// CHECK-NEXT: flow.dispatch.output.store %[[RET_VALUE]], %[[RET]] : tensor<4x8xf32> -> !flow.dispatch.output<4x8xf32> +// CHECK-NEXT: flow.dispatch.tensor.store %[[RET_VALUE]], %[[RET]] : tensor<4x8xf32> -> !flow.dispatch.tensor // CHECK-NEXT: return // CHECK-NEXT: } @@ -24,15 +24,15 @@ func @staticShapeDispatch(%arg0 : tensor<8x4xf32>) -> tensor<4x8xf32> { %y = constant 50 : index // CHECK: %[[RET:.+]] = flow.dispatch @staticShapeDispatch_dispatch_0::@staticShapeDispatch_dispatch_0[ // CHECK-SAME: %[[X]], %[[Y]] - // CHECK-SAME: ](%[[ARG0]]) : (tensor<8x4xf32>) -> (tensor<4x8xf32>) + // CHECK-SAME: ](%[[ARG0]]) : (tensor<8x4xf32>) -> tensor<4x8xf32> %0 = flow.dispatch.workgroups[%x, %y](%arg0) : (tensor<8x4xf32>) -> tensor<4x8xf32> = ( - %arg: !flow.dispatch.input<8x4xf32>, %ret: !flow.dispatch.output<4x8xf32> + %arg: !flow.dispatch.tensor, %ret: !flow.dispatch.tensor ) { - %arg_value = flow.dispatch.input.load %arg : !flow.dispatch.input<8x4xf32> -> tensor<8x4xf32> - %arg_shape = flow.dispatch.shape %arg : !flow.dispatch.input<8x4xf32> -> !shapex.ranked_shape<[8,4]> - %ret_shape = flow.dispatch.shape %ret : !flow.dispatch.output<4x8xf32> -> !shapex.ranked_shape<[4,8]> + %arg_value = flow.dispatch.tensor.load %arg : !flow.dispatch.tensor -> tensor<8x4xf32> + %arg_shape = flow.dispatch.shape %arg : !flow.dispatch.tensor -> !shapex.ranked_shape<[8,4]> + %ret_shape = flow.dispatch.shape %ret : !flow.dispatch.tensor -> !shapex.ranked_shape<[4,8]> %ret_value = "test.sink"(%arg_value, %arg_shape, %ret_shape) : (tensor<8x4xf32>, !shapex.ranked_shape<[8,4]>, !shapex.ranked_shape<[4,8]>) -> (tensor<4x8xf32>) - flow.dispatch.output.store %ret_value, %ret : tensor<4x8xf32> -> !flow.dispatch.output<4x8xf32> + flow.dispatch.tensor.store %ret_value, %ret : tensor<4x8xf32> -> !flow.dispatch.tensor flow.return } // CHECK-NEXT: return %[[RET]] @@ -62,28 +62,28 @@ func @dispatchFnMuli(%arg0 : tensor<8x4xf32>) -> tensor<8x4xf32> { %y = constant 50 : index // CHECK: %[[RET0:.+]] = flow.dispatch @dispatchFnMuli_dispatch_0::@dispatchFnMuli_dispatch_0[ // CHECK-SAME: %[[X]], %[[Y]] - // CHECK-SAME: ](%[[ARG0]]) : (tensor<8x4xf32>) -> (tensor<4x8xf32>) + // CHECK-SAME: ](%[[ARG0]]) : (tensor<8x4xf32>) -> tensor<4x8xf32> %0 = flow.dispatch.workgroups[%x, %y](%arg0) : (tensor<8x4xf32>) -> (tensor<4x8xf32>) = ( - %arg: !flow.dispatch.input<8x4xf32>, %ret: !flow.dispatch.output<4x8xf32> + %arg: !flow.dispatch.tensor, %ret: !flow.dispatch.tensor ) { - %arg_value = flow.dispatch.input.load %arg : !flow.dispatch.input<8x4xf32> -> tensor<8x4xf32> - %arg_shape = flow.dispatch.shape %arg : !flow.dispatch.input<8x4xf32> -> !shapex.ranked_shape<[8,4]> - %ret_shape = flow.dispatch.shape %ret : !flow.dispatch.output<4x8xf32> -> !shapex.ranked_shape<[4,8]> + %arg_value = flow.dispatch.tensor.load %arg : !flow.dispatch.tensor -> tensor<8x4xf32> + %arg_shape = flow.dispatch.shape %arg : !flow.dispatch.tensor -> !shapex.ranked_shape<[8,4]> + %ret_shape = flow.dispatch.shape %ret : !flow.dispatch.tensor -> !shapex.ranked_shape<[4,8]> %ret_value = "test.sink1"(%arg_value, %arg_shape, %ret_shape) : (tensor<8x4xf32>, !shapex.ranked_shape<[8,4]>, !shapex.ranked_shape<[4,8]>) -> (tensor<4x8xf32>) - flow.dispatch.output.store %ret_value, %ret : tensor<4x8xf32> -> !flow.dispatch.output<4x8xf32> + flow.dispatch.tensor.store %ret_value, %ret : tensor<4x8xf32> -> !flow.dispatch.tensor flow.return } // CHECK: %[[RET1:.+]] = flow.dispatch @dispatchFnMuli_dispatch_1::@dispatchFnMuli_dispatch_1[ // CHECK-SAME: %[[Y]], %[[X]] - // CHECK-SAME: ](%[[RET0]]) : (tensor<4x8xf32>) -> (tensor<8x4xf32>) + // CHECK-SAME: ](%[[RET0]]) : (tensor<4x8xf32>) -> tensor<8x4xf32> %1 = flow.dispatch.workgroups[%y, %x](%0) : (tensor<4x8xf32>) -> (tensor<8x4xf32>) = ( - %arg: !flow.dispatch.input<4x8xf32>, %ret: !flow.dispatch.output<8x4xf32> + %arg: !flow.dispatch.tensor, %ret: !flow.dispatch.tensor ) { - %arg_value = flow.dispatch.input.load %arg : !flow.dispatch.input<4x8xf32> -> tensor<8x4xf32> - %arg_shape = flow.dispatch.shape %arg : !flow.dispatch.input<4x8xf32> -> !shapex.ranked_shape<[4,8]> - %ret_shape = flow.dispatch.shape %ret : !flow.dispatch.output<8x4xf32> -> !shapex.ranked_shape<[8,4]> + %arg_value = flow.dispatch.tensor.load %arg : !flow.dispatch.tensor -> tensor<8x4xf32> + %arg_shape = flow.dispatch.shape %arg : !flow.dispatch.tensor -> !shapex.ranked_shape<[4,8]> + %ret_shape = flow.dispatch.shape %ret : !flow.dispatch.tensor -> !shapex.ranked_shape<[8,4]> %ret_value = "test.sink2"(%arg_value, %arg_shape, %ret_shape) : (tensor<8x4xf32>, !shapex.ranked_shape<[4,8]>, !shapex.ranked_shape<[8,4]>) -> (tensor<8x4xf32>) - flow.dispatch.output.store %ret_value, %ret : tensor<8x4xf32> -> !flow.dispatch.output<8x4xf32> + flow.dispatch.tensor.store %ret_value, %ret : tensor<8x4xf32> -> !flow.dispatch.tensor flow.return } // CHECK-NEXT: return %[[RET1]] @@ -100,7 +100,7 @@ func @dispatchFn1(%arg0 : tensor<8x4xf32>) -> tensor<4x8xf32> { %y = constant 50 : index // CHECK: flow.dispatch @dispatchFn1_dispatch_0::@dispatchFn1_dispatch_0 %0 = flow.dispatch.workgroups[%x, %y](%arg0) : (tensor<8x4xf32>) -> (tensor<4x8xf32>) = ( - %arg: !flow.dispatch.input<8x4xf32>, %ret: !flow.dispatch.output<4x8xf32> + %arg: !flow.dispatch.tensor, %ret: !flow.dispatch.tensor ) { flow.return } @@ -115,7 +115,7 @@ func @dispatchFn2(%arg0 : tensor<8x4xf32>) -> tensor<4x8xf32> { %y = constant 50 : index // CHECK: flow.dispatch @dispatchFn2_dispatch_0::@dispatchFn2_dispatch_0 %0 = flow.dispatch.workgroups[%x, %y](%arg0) : (tensor<8x4xf32>) -> (tensor<4x8xf32>) = ( - %arg: !flow.dispatch.input<8x4xf32>, %ret: !flow.dispatch.output<4x8xf32> + %arg: !flow.dispatch.tensor, %ret: !flow.dispatch.tensor ) { flow.return } @@ -129,8 +129,8 @@ func @dispatchFn2(%arg0 : tensor<8x4xf32>) -> tensor<4x8xf32> { // CHECK-SAME: signature = (tensor<7x?x24x?xf32>) -> tensor, // CHECK-SAME: workgroup_rank = 2 : index} // CHECK: func @dynamicShapeDispatch_dispatch_0( -// CHECK-SAME: %[[ARG:.+]]: !flow.dispatch.input<7x?x24x?xf32>, -// CHECK-SAME: %[[RET:.+]]: !flow.dispatch.output, +// CHECK-SAME: %[[ARG:.+]]: !flow.dispatch.tensor, +// CHECK-SAME: %[[RET:.+]]: !flow.dispatch.tensor, // CHECK-SAME: %[[IN_ARG_DIM1:.+]]: index, %[[IN_ARG_DIM3:.+]]: index, %[[IN_RET_DIM0:.+]]: index, %[[IN_RET_DIM1:.+]]: index) { // CHECK: %[[IN_ARG_SHAPE:.+]] = shapex.make_ranked_shape %[[IN_ARG_DIM1]], %[[IN_ARG_DIM3]] @@ -147,10 +147,10 @@ func @dispatchFn2(%arg0 : tensor<8x4xf32>) -> tensor<4x8xf32> { // CHECK-DAG: %[[RET_DIM1:.+]] = shapex.ranked_dim %[[RET_SHAPE]][1] // CHECK-NEXT: "test.sink_shape_ret"(%[[RET_DIM0]], %[[RET_DIM1]]) -// CHECK: %[[ARG_TILE:.+]] = flow.dispatch.input.load %[[ARG_SHAPED]] +// CHECK: %[[ARG_TILE:.+]] = flow.dispatch.tensor.load %[[ARG_SHAPED]] // CHECK-NEXT: %[[ARG_TILE_SHAPE:.+]] = shapex.get_ranked_shape %[[ARG_TILE]] // CHECK-NEXT: %[[RET_TILE:.+]] = "test.tile_math"(%[[ARG_TILE]], %[[ARG_TILE_SHAPE]], %[[RET_SHAPE]]) -// CHECK-NEXT: flow.dispatch.output.store %[[RET_TILE]], %[[RET_SHAPED]] +// CHECK-NEXT: flow.dispatch.tensor.store %[[RET_TILE]], %[[RET_SHAPED]] // CHECK: return // CHECK-NEXT: } @@ -177,29 +177,29 @@ func @dynamicShapeDispatch(%arg0 : tensor<7x?x24x?xf32>) -> tensor // CHECK-NEXT: %[[RET0:.+]] = flow.dispatch @dynamicShapeDispatch_dispatch_0::@dynamicShapeDispatch_dispatch_0[ // CHECK-SAME: %[[X]], %[[Y]] // CHECK-SAME: ](%arg0, %[[IN_ARG0_DIM1]], %[[IN_ARG0_DIM3]], %[[IN_RET0_DIM0]], %[[IN_RET0_DIM1]]) - // CHECK-SAME: : (tensor<7x?x24x?xf32>{%[[IN_ARG0_DIM1]], %[[IN_ARG0_DIM3]]}, index, index, index, index) -> (tensor{%[[IN_RET0_DIM0]], %[[IN_RET0_DIM1]]}) + // CHECK-SAME: : (tensor<7x?x24x?xf32>{%[[IN_ARG0_DIM1]], %[[IN_ARG0_DIM3]]}, index, index, index, index) -> tensor{%[[IN_RET0_DIM0]], %[[IN_RET0_DIM1]]} %ret0 = flow.dispatch.workgroups[%x, %y](%arg0) : (tensor<7x?x24x?xf32>{%dim1, %dim3}) -> tensor{%dim3, %dim1} = ( - %arg: !flow.dispatch.input<7x?x24x?xf32>, %ret: !flow.dispatch.output + %arg: !flow.dispatch.tensor, %ret: !flow.dispatch.tensor ) { %workgroup_rank = flow.dispatch.workgroup.rank : index - %arg_shape = flow.dispatch.shape %arg : !flow.dispatch.input<7x?x24x?xf32> -> !shapex.ranked_shape<[7,?,24,?]> + %arg_shape = flow.dispatch.shape %arg : !flow.dispatch.tensor -> !shapex.ranked_shape<[7,?,24,?]> %arg_dim1 = shapex.ranked_dim %arg_shape[1] : !shapex.ranked_shape<[7,?,24,?]> -> index %arg_dim3 = shapex.ranked_dim %arg_shape[3] : !shapex.ranked_shape<[7,?,24,?]> -> index "test.sink_shape_arg"(%arg_dim1, %arg_dim3) : (index, index) -> () - %ret_shape = flow.dispatch.shape %ret : !flow.dispatch.output -> !shapex.ranked_shape<[?,?,1024]> + %ret_shape = flow.dispatch.shape %ret : !flow.dispatch.tensor -> !shapex.ranked_shape<[?,?,1024]> %ret_dim0 = shapex.ranked_dim %ret_shape[0] : !shapex.ranked_shape<[?,?,1024]> -> index %ret_dim1 = shapex.ranked_dim %ret_shape[1] : !shapex.ranked_shape<[?,?,1024]> -> index "test.sink_shape_ret"(%ret_dim0, %ret_dim1) : (index, index) -> () - %arg_tile = flow.dispatch.input.load %arg : !flow.dispatch.input<7x?x24x?xf32> -> tensor<7x?x24x?xf32> + %arg_tile = flow.dispatch.tensor.load %arg : !flow.dispatch.tensor -> tensor<7x?x24x?xf32> %arg_tile_shape = shapex.get_ranked_shape %arg_tile : tensor<7x?x24x?xf32> -> !shapex.ranked_shape<[7,?,24,?]> %ret_tile = "test.tile_math"(%arg_tile, %arg_tile_shape, %ret_shape) : (tensor<7x?x24x?xf32>, !shapex.ranked_shape<[7,?,24,?]>, !shapex.ranked_shape<[?,?,1024]>) -> (tensor) - flow.dispatch.output.store %ret_tile, %ret : tensor -> !flow.dispatch.output + flow.dispatch.tensor.store %ret_tile, %ret : tensor -> !flow.dispatch.tensor flow.return } diff --git a/iree/compiler/Dialect/Flow/Transforms/test/transformation.mlir b/iree/compiler/Dialect/Flow/Transforms/test/transformation.mlir index 5c17a0a28159..5067a91cf1b1 100644 --- a/iree/compiler/Dialect/Flow/Transforms/test/transformation.mlir +++ b/iree/compiler/Dialect/Flow/Transforms/test/transformation.mlir @@ -27,10 +27,10 @@ func @stdElementwiseOps(%arg0 : tensor<4xf32>) -> tensor<4xf32> { // CHECK-NEXT: } // CHECK-NEXT: } // CHECK-NEXT: func @stdElementwiseOps(%arg0: tensor<4xf32>) -> tensor<4xf32> { -// CHECK-NEXT: %0 = flow.ex.stream.fragment(%arg0) : (tensor<4xf32>) -> (tensor<4xf32>) = -// CHECK-NEXT: (%arg1: tensor<4xf32>) -> (tensor<4xf32>) { +// CHECK-NEXT: %0 = flow.ex.stream.fragment(%arg0) : (tensor<4xf32>) -> tensor<4xf32> = +// CHECK-NEXT: (%arg1: tensor<4xf32>) -> tensor<4xf32> { // CHECK-NEXT: %[[WORKLOAD:.+]] = constant 4 : index -// CHECK-NEXT: %1 = flow.dispatch @stdElementwiseOps_ex_dispatch_0::@stdElementwiseOps_ex_dispatch_0[%[[WORKLOAD]]](%arg1) : (tensor<4xf32>) -> (tensor<4xf32> +// CHECK-NEXT: %1 = flow.dispatch @stdElementwiseOps_ex_dispatch_0::@stdElementwiseOps_ex_dispatch_0[%[[WORKLOAD]]](%arg1) : (tensor<4xf32>) -> tensor<4xf32> // CHECK-NEXT: flow.return %1 : tensor<4xf32> // CHECK-NEXT: } // CHECK-NEXT: return %0 : tensor<4xf32> @@ -57,10 +57,10 @@ func @hloElementwiseOps(%arg0 : tensor<4xf32>) -> tensor<4xf32> { // CHECK-NEXT: } // CHECK-NEXT: } // CHECK-NEXT: func @hloElementwiseOps(%arg0: tensor<4xf32>) -> tensor<4xf32> { -// CHECK-NEXT: %0 = flow.ex.stream.fragment(%arg0) : (tensor<4xf32>) -> (tensor<4xf32>) = -// CHECK-NEXT: (%arg1: tensor<4xf32>) -> (tensor<4xf32>) { +// CHECK-NEXT: %0 = flow.ex.stream.fragment(%arg0) : (tensor<4xf32>) -> tensor<4xf32> = +// CHECK-NEXT: (%arg1: tensor<4xf32>) -> tensor<4xf32> { // CHECK-NEXT: %[[WORKLOAD:.+]] = constant 4 : index -// CHECK-NEXT: %1 = flow.dispatch @hloElementwiseOps_ex_dispatch_0::@hloElementwiseOps_ex_dispatch_0[%[[WORKLOAD]]](%arg1) : (tensor<4xf32>) -> (tensor<4xf32>) +// CHECK-NEXT: %1 = flow.dispatch @hloElementwiseOps_ex_dispatch_0::@hloElementwiseOps_ex_dispatch_0[%[[WORKLOAD]]](%arg1) : (tensor<4xf32>) -> tensor<4xf32> // CHECK-NEXT: flow.return %1 : tensor<4xf32> // CHECK-NEXT: } // CHECK-NEXT: return %0 : tensor<4xf32> @@ -103,12 +103,12 @@ func @interleavedDot(%arg0 : tensor<4x4xf32>) -> tensor<4x4xf32> { // CHECK-NEXT: } // CHECK-NEXT: } // CHECK-NEXT: func @interleavedDot(%arg0: tensor<4x4xf32>) -> tensor<4x4xf32> { -// CHECK-NEXT: %0 = flow.ex.stream.fragment(%arg0) : (tensor<4x4xf32>) -> (tensor<4x4xf32>) = -// CHECK-NEXT: (%arg1: tensor<4x4xf32>) -> (tensor<4x4xf32>) { +// CHECK-NEXT: %0 = flow.ex.stream.fragment(%arg0) : (tensor<4x4xf32>) -> tensor<4x4xf32> = +// CHECK-NEXT: (%arg1: tensor<4x4xf32>) -> tensor<4x4xf32> { // CHECK-NEXT: %[[WORKLOAD:.+]] = constant 16 : index -// CHECK-NEXT: %1 = flow.dispatch @interleavedDot_ex_dispatch_0::@interleavedDot_ex_dispatch_0[%[[WORKLOAD]]](%arg1) : (tensor<4x4xf32>) -> (tensor<4x4xf32>) -// CHECK-NEXT: %2 = flow.dispatch @interleavedDot_ex_dispatch_1::@interleavedDot_ex_dispatch_1[%[[WORKLOAD]]](%1, %arg1) : (tensor<4x4xf32>, tensor<4x4xf32>) -> (tensor<4x4xf32>) -// CHECK-NEXT: %3 = flow.dispatch @interleavedDot_ex_dispatch_2::@interleavedDot_ex_dispatch_2[%[[WORKLOAD]]](%2, %arg1) : (tensor<4x4xf32>, tensor<4x4xf32>) -> (tensor<4x4xf32>) +// CHECK-NEXT: %1 = flow.dispatch @interleavedDot_ex_dispatch_0::@interleavedDot_ex_dispatch_0[%[[WORKLOAD]]](%arg1) : (tensor<4x4xf32>) -> tensor<4x4xf32> +// CHECK-NEXT: %2 = flow.dispatch @interleavedDot_ex_dispatch_1::@interleavedDot_ex_dispatch_1[%[[WORKLOAD]]](%1, %arg1) : (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> +// CHECK-NEXT: %3 = flow.dispatch @interleavedDot_ex_dispatch_2::@interleavedDot_ex_dispatch_2[%[[WORKLOAD]]](%2, %arg1) : (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> // CHECK-NEXT: flow.return %3 : tensor<4x4xf32> // CHECK-NEXT: } // CHECK-NEXT: return %0 : tensor<4x4xf32> @@ -141,10 +141,10 @@ func @reduction(%arg0 : tensor<4x8xf32>) -> tensor<4xf32> { // CHECK-NEXT: } // CHECK-NEXT: } // CHECK-NEXT: func @reduction(%arg0: tensor<4x8xf32>) -> tensor<4xf32> { -// CHECK-NEXT: %0 = flow.ex.stream.fragment(%arg0) : (tensor<4x8xf32>) -> (tensor<4xf32>) = -// CHECK-NEXT: (%arg1: tensor<4x8xf32>) -> (tensor<4xf32>) { +// CHECK-NEXT: %0 = flow.ex.stream.fragment(%arg0) : (tensor<4x8xf32>) -> tensor<4xf32> = +// CHECK-NEXT: (%arg1: tensor<4x8xf32>) -> tensor<4xf32> { // CHECK-NEXT: %[[WORKLOAD:.+]] = constant 4 : index -// CHECK-NEXT: %1 = flow.dispatch @reduction_ex_dispatch_0::@reduction_ex_dispatch_0[%[[WORKLOAD]]](%arg1) : (tensor<4x8xf32>) -> (tensor<4xf32>) +// CHECK-NEXT: %1 = flow.dispatch @reduction_ex_dispatch_0::@reduction_ex_dispatch_0[%[[WORKLOAD]]](%arg1) : (tensor<4x8xf32>) -> tensor<4xf32> // CHECK-NEXT: flow.return %1 : tensor<4xf32> // CHECK-NEXT: } // CHECK-NEXT: return %0 : tensor<4xf32> @@ -172,11 +172,12 @@ func @dynamicUpdateSlice(%operand : tensor<2x4xi32>, %update : tensor<1x1xi32>, // CHECK-DAG: %[[ARG2_INDEX:.+]] = index_cast %[[ARG2_LOAD]] : i32 to index // CHECK-DAG: %[[ARG3_LOAD:.+]] = flow.tensor.load %arg3 : tensor // CHECK-DAG: %[[ARG3_INDEX:.+]] = index_cast %[[ARG3_LOAD]] : i32 to index -// CHECK-NEXT: %[[RET:.+]] = flow.ex.stream.fragment(%arg0, %[[ARG2_INDEX]], %[[ARG3_INDEX]], %arg1) : (tensor<2x4xi32>, index, index, tensor<1x1xi32>) -> (tensor<2x4xi32>) = -// CHECK-NEXT: (%arg4: tensor<2x4xi32>, %arg5: index, %arg6: index, %arg7: tensor<1x1xi32>) -> (tensor<2x4xi32>) { +// CHECK-NEXT: %[[RET:.+]] = flow.ex.stream.fragment(%arg0, %[[ARG2_INDEX]], %[[ARG3_INDEX]], %arg1) : (tensor<2x4xi32>, index, index, tensor<1x1xi32>) -> tensor<2x4xi32> = +// CHECK-NEXT: (%arg4: tensor<2x4xi32>, %arg5: index, %arg6: index, %arg7: tensor<1x1xi32>) -> tensor<2x4xi32> { // CHECK-NEXT: %[[WORKLOAD:.+]] = constant 8 : index +// CHECK-NEXT: %[[ARG4_CLONE:.+]] = flow.tensor.clone %arg4 : tensor<2x4xi32> // CHECK-NEXT: %[[T0:.+]] = flow.tensor.update %arg7, %arg4[%arg5, %arg6] : tensor<1x1xi32> -> tensor<2x4xi32> -// CHECK-NEXT: %[[T1:.+]] = flow.dispatch @dynamicUpdateSlice_ex_dispatch_0::@dynamicUpdateSlice_ex_dispatch_0[%[[WORKLOAD]]](%arg4, %[[T0]]) : (tensor<2x4xi32>, tensor<2x4xi32>) -> (tensor<2x4xi32>) +// CHECK-NEXT: %[[T1:.+]] = flow.dispatch @dynamicUpdateSlice_ex_dispatch_0::@dynamicUpdateSlice_ex_dispatch_0[%[[WORKLOAD]]](%[[ARG4_CLONE]], %[[T0]]) : (tensor<2x4xi32>, tensor<2x4xi32>) -> tensor<2x4xi32> // CHECK-NEXT: flow.return %[[T1]] : tensor<2x4xi32> // CHECK-NEXT: } // CHECK-NEXT: return %[[RET]] : tensor<2x4xi32> diff --git a/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/ConvertStreamOps.cpp b/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/ConvertStreamOps.cpp index 643c52160d5a..22ee0c2de3b9 100644 --- a/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/ConvertStreamOps.cpp +++ b/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/ConvertStreamOps.cpp @@ -66,6 +66,20 @@ static bool isNoOp(Operation *op) { return isa(op); } // identity. static bool isIdentityOp(Operation *op) { return isa(op); } +// HACK: until we are doing buffer allocation we need to pin tied buffers all +// the way up the stream from the outputs. +static void propagateTiedBuffer(BufferSet &bufferSet, Value streamValue, + BufferRange bufferRange) { + Value baseValue = streamValue; + while (auto definingOp = dyn_cast_or_null( + baseValue.getDefiningOp())) { + auto tiedValue = definingOp.getTiedResultOperand(baseValue); + if (!tiedValue) break; + baseValue = tiedValue; + bufferSet.rangeMap[baseValue] = bufferRange; + } +} + // Allocates a buffer for the given stream output value. // |streamValue| is the Value used within the stream region and // |externalValue| is the returned value from the stream region in the parent @@ -110,17 +124,36 @@ static Value allocateOutputBuffer(Value streamValue, Value externalValue, static void allocateOutputBuffers(IREE::Flow::ExStreamFragmentOp streamOp, BufferSet &bufferSet, ConversionPatternRewriter &rewriter) { + auto tiedStreamOp = cast(streamOp.getOperation()); + auto &entryBlock = streamOp.body().front(); + // Allocate output buffers and replace the original uses with the buffers. auto returnOp = cast(streamOp.body().front().back()); for (auto result : llvm::enumerate(streamOp.getResults())) { auto streamValue = returnOp.getOperand(result.index()); auto externalValue = result.value(); - auto buffer = allocateOutputBuffer(streamValue, externalValue, - bufferSet.allocator, rewriter); - auto bufferRange = BufferRange{buffer}; + + // Tied results reuse their operand buffer. + BufferRange bufferRange; + auto tiedOperandIndex = + tiedStreamOp.getTiedResultOperandIndex(result.index()); + if (tiedOperandIndex.hasValue()) { + LLVM_DEBUG(llvm::dbgs() + << " -- REUSING TIED OPERAND(" + << tiedOperandIndex.getValue() << ") BUFFER FOR STREAM RESULT(" + << result.index() << "): " << streamOp << "\n"); + auto operand = entryBlock.getArgument(tiedOperandIndex.getValue()); + bufferRange = bufferSet.rangeMap[operand]; + } else { + auto buffer = allocateOutputBuffer(streamValue, externalValue, + bufferSet.allocator, rewriter); + bufferRange = BufferRange{buffer}; + } + assert(bufferRange.buffer); + bufferSet.outputBuffers.push_back(bufferRange.buffer); bufferSet.rangeMap[externalValue] = bufferRange; bufferSet.rangeMap[streamValue] = bufferRange; - bufferSet.outputBuffers.push_back(buffer); + propagateTiedBuffer(bufferSet, streamValue, bufferRange); } } @@ -221,20 +254,41 @@ static void allocateTransientBuffers(IREE::Flow::ExStreamFragmentOp streamOp, } for (auto &op : streamOp.body().front()) { if (isNoOp(&op) || isIdentityOp(&op)) continue; + auto tiedOp = dyn_cast(op); for (auto it : llvm::enumerate(op.getResults())) { auto result = it.value(); if (!result.getType().isa()) continue; + // If the result is an output buffer we can just use that directly. if (bufferSet.rangeMap[result].buffer) { LLVM_DEBUG(llvm::dbgs() << " -- SKIP ALREADY SET BUFFER RESULT(" << it.index() << "): " << op << "\n"); continue; } + + // Tied results reuse their operand buffer. + if (tiedOp) { + auto tiedOperandIndex = tiedOp.getTiedResultOperandIndex(it.index()); + if (tiedOperandIndex.hasValue()) { + LLVM_DEBUG(llvm::dbgs() + << " -- REUSING TIED OPERAND(" + << tiedOperandIndex.getValue() << ") BUFFER FOR RESULT(" + << it.index() << "): " << op << "\n"); + auto operand = op.getOperand(tiedOperandIndex.getValue()); + auto operandBufferRange = bufferSet.rangeMap[operand]; + assert(operandBufferRange.buffer); + bufferSet.rangeMap[result] = operandBufferRange; + continue; + } + } + LLVM_DEBUG(llvm::dbgs() << " -- ALLOCATE BUFFER FOR RESULT(" << it.index() << "): " << op << "\n"); auto buffer = allocateTransientBuffer(result, bufferSet.allocator, rewriter); - bufferSet.rangeMap[result] = BufferRange{buffer}; + auto bufferRange = BufferRange{buffer}; + bufferSet.rangeMap[result] = bufferRange; + propagateTiedBuffer(bufferSet, result, bufferRange); } } while (propagateIdentityBuffers()) { @@ -450,13 +504,44 @@ static LogicalResult recordDispatch(Value device, Value commandBuffer, return success(); } +static LogicalResult recordTensorClone(Value device, Value commandBuffer, + IREE::Flow::TensorCloneOp &cloneOp, + BufferSet &bufferSet, + ConversionPatternRewriter &rewriter) { + auto &operandBuffer = bufferSet.rangeMap[cloneOp.operand()]; + auto &resultBuffer = bufferSet.rangeMap[cloneOp.result()]; + + auto operand = IREE::HAL::TensorRewriteAdaptor::getChecked( + cloneOp.getLoc(), cloneOp.operand(), operandBuffer.buffer, rewriter); + auto result = IREE::HAL::TensorRewriteAdaptor::getChecked( + cloneOp.getLoc(), cloneOp.result(), resultBuffer.buffer, rewriter); + if (!operand.hasValue() || !result.hasValue()) { + return cloneOp.emitOpError() + << "cannot create adaptors for tensor clone operands/results"; + } + + auto zeroOffset = + rewriter.createOrFold(cloneOp.getLoc(), 0); + auto byteLength = operand->getByteLength(); + if (!byteLength) return failure(); + + rewriter.create( + cloneOp.getLoc(), commandBuffer, operand->getBuffer(), zeroOffset, + result->getBuffer(), zeroOffset, byteLength); + + // Full barriers for now as we aren't scheduling things. + // TODO(benvanik): don't add at the end of the command buffer (we could + // also do a canonicalization step that removed trailing barriers). + recordFullExecutionBarrier(commandBuffer, cloneOp.getLoc(), rewriter); + return success(); +} + static LogicalResult recordTensorUpdate(Value device, Value commandBuffer, IREE::Flow::TensorUpdateOp &updateOp, BufferSet &bufferSet, ConversionPatternRewriter &rewriter) { auto &updateBuffer = bufferSet.rangeMap[updateOp.update()]; auto &targetBuffer = bufferSet.rangeMap[updateOp.target()]; - auto &resultBuffer = bufferSet.rangeMap[updateOp.result()]; // TODO(benvanik): use something other than the BufferRange::buffer? // This may require us to subview the buffer first. @@ -464,9 +549,7 @@ static LogicalResult recordTensorUpdate(Value device, Value commandBuffer, updateOp.getLoc(), updateOp.update(), updateBuffer.buffer, rewriter); auto target = IREE::HAL::TensorRewriteAdaptor::getChecked( updateOp.getLoc(), updateOp.target(), targetBuffer.buffer, rewriter); - auto result = IREE::HAL::TensorRewriteAdaptor::getChecked( - updateOp.getLoc(), updateOp.result(), resultBuffer.buffer, rewriter); - if (!update.hasValue() || !target.hasValue() || !result.hasValue()) { + if (!update.hasValue() || !target.hasValue()) { return updateOp.emitOpError() << "cannot create adaptors for tensor update operands/results"; } @@ -484,18 +567,10 @@ static LogicalResult recordTensorUpdate(Value device, Value commandBuffer, target->computeRange(startIndices, *update->getShapeDims()); if (!targetRange) return failure(); - // TODO(benvanik): actual buffer allocation so we aren't doing this copy. - auto targetByteLength = target->getByteLength(); - if (!targetByteLength) return failure(); - - rewriter.create( - updateOp.getLoc(), commandBuffer, target->getBuffer(), zeroOffset, - result->getBuffer(), zeroOffset, targetByteLength); // TODO(benvanik): slice left/mid/right, but really just don't do this. - recordFullExecutionBarrier(commandBuffer, updateOp.getLoc(), rewriter); rewriter.create( updateOp.getLoc(), commandBuffer, update->getBuffer(), zeroOffset, - result->getBuffer(), targetRange->offset, targetRange->length); + target->getBuffer(), targetRange->offset, targetRange->length); // Full barriers for now as we aren't scheduling things. // TODO(benvanik): don't add at the end of the command buffer (we could @@ -514,6 +589,11 @@ static LogicalResult recordStreamCommands(Value device, Value commandBuffer, rewriter))) { return failure(); } + } else if (auto cloneOp = dyn_cast(op)) { + if (failed(recordTensorClone(device, commandBuffer, cloneOp, bufferSet, + rewriter))) { + return failure(); + } } else if (auto updateOp = dyn_cast(op)) { if (failed(recordTensorUpdate(device, commandBuffer, updateOp, bufferSet, rewriter))) { diff --git a/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/test/stream_ops.mlir b/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/test/stream_ops.mlir index 3ba8d894906e..54d17cd6a8c4 100644 --- a/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/test/stream_ops.mlir +++ b/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/test/stream_ops.mlir @@ -55,12 +55,11 @@ func @tensorUpdate(%arg0 : tensor<1x1x10xf32>, %arg1 : tensor<5x1x10xf32>) -> te // CHECK-NEXT: hal.command_buffer.begin %[[CMD]] %0 = flow.ex.stream.fragment(%arg0, %arg1, %c4, %c1) : (tensor<1x1x10xf32>, tensor<5x1x10xf32>, index, index) -> tensor<5x1x10xf32> = (%arg2: tensor<1x1x10xf32>, %arg3: tensor<5x1x10xf32>, %arg4: index, %arg5: index) -> tensor<5x1x10xf32> { - // TODO(laurenzo): Update these checks to be more precise. The regexes can - // match too much, masking issues. // CHECK-NEXT: hal.command_buffer.copy_buffer %[[CMD]], %[[TBUF]], %c0, %[[RET_BUF]], %c0, %c200 // CHECK: hal.command_buffer.execution_barrier + %clone = flow.tensor.clone %arg3 : tensor<5x1x10xf32> // CHECK-NEXT: hal.command_buffer.copy_buffer %[[CMD]], %[[UBUF]], %c0, %[[RET_BUF]], %c204, %c40 - %1 = flow.tensor.update %arg2, %arg3[%arg4, %arg5, %arg5] : tensor<1x1x10xf32> -> tensor<5x1x10xf32> + %1 = flow.tensor.update %arg2, %clone[%arg4, %arg5, %arg5] : tensor<1x1x10xf32> -> tensor<5x1x10xf32> flow.return %1 : tensor<5x1x10xf32> } // CHECK: hal.command_buffer.end %[[CMD]] @@ -118,7 +117,7 @@ hal.executable @ex attributes {sym_visibility = "private"} { hal.executable.entry_point @entry attributes { interface = @legacy_io, ordinal = 0 : i32, - signature = (!flow.dispatch.input<7x4x24xf32>, !flow.dispatch.output<4x7x1024xf32>) -> () + signature = (!flow.dispatch.tensor, !flow.dispatch.tensor) -> () } module {} } @@ -152,7 +151,7 @@ hal.executable @ex attributes {sym_visibility = "private"} { hal.executable.entry_point @entry attributes { interface = @legacy_io, ordinal = 0 : i32, - signature = (!flow.dispatch.input<7x?x24x?xf32>, !flow.dispatch.output, index, index, index, index) -> () + signature = (!flow.dispatch.tensor, !flow.dispatch.tensor, index, index, index, index) -> () } module {} } diff --git a/iree/compiler/Dialect/HAL/Target/CUDA/test/smoketest.mlir b/iree/compiler/Dialect/HAL/Target/CUDA/test/smoketest.mlir index de75c9d1f4cc..26c079652628 100644 --- a/iree/compiler/Dialect/HAL/Target/CUDA/test/smoketest.mlir +++ b/iree/compiler/Dialect/HAL/Target/CUDA/test/smoketest.mlir @@ -7,16 +7,16 @@ flow.executable @add_dispatch_0 { workgroup_rank = 3 : index } module { - func @add_dispatch_0(%arg0: !flow.dispatch.input<16xf32>, %arg1: !flow.dispatch.input<16xf32>, %arg2: !flow.dispatch.output<16xf32>) { + func @add_dispatch_0(%arg0: !flow.dispatch.tensor, %arg1: !flow.dispatch.tensor, %arg2: !flow.dispatch.tensor) { %0 = linalg.init_tensor [16] : tensor<16xf32> - %1 = flow.dispatch.input.load %arg0 : !flow.dispatch.input<16xf32> -> tensor<16xf32> - %2 = flow.dispatch.input.load %arg1 : !flow.dispatch.input<16xf32> -> tensor<16xf32> + %1 = flow.dispatch.tensor.load %arg0 : !flow.dispatch.tensor -> tensor<16xf32> + %2 = flow.dispatch.tensor.load %arg1 : !flow.dispatch.tensor -> tensor<16xf32> %3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%1, %2 : tensor<16xf32>, tensor<16xf32>) outs(%0 : tensor<16xf32>) { ^bb0(%arg3: f32, %arg4: f32, %arg5: f32): // no predecessors %4 = addf %arg3, %arg4 : f32 linalg.yield %4 : f32 } -> tensor<16xf32> - flow.dispatch.output.store %3, %arg2 : tensor<16xf32> -> !flow.dispatch.output<16xf32> + flow.dispatch.tensor.store %3, %arg2 : tensor<16xf32> -> !flow.dispatch.tensor return } } diff --git a/iree/compiler/Dialect/HAL/Transforms/MaterializeInterfaces.cpp b/iree/compiler/Dialect/HAL/Transforms/MaterializeInterfaces.cpp index 404969b11c06..69d4da0b2a61 100644 --- a/iree/compiler/Dialect/HAL/Transforms/MaterializeInterfaces.cpp +++ b/iree/compiler/Dialect/HAL/Transforms/MaterializeInterfaces.cpp @@ -70,29 +70,43 @@ static llvm::Optional declareInterfaceIO( // NOTE: we assume right now that all entry points have the same signature. // TODO(benvanik): replace when we have descriptor sets in the HAL IR. auto anyFuncOp = entryFuncOps.front(); - int binding = 0; + int nextBindingOrdinal = 0; int pushConstantCount = 0; - int argOrdinal = 0; - int retOrdinal = 0; for (auto inputType : llvm::enumerate(anyFuncOp.getType().getInputs())) { if (inputType.value().isa()) { + int bindingOrdinal = nextBindingOrdinal++; auto bindingName = "arg" + std::to_string(inputType.index()); interfaceBuilder.create( - interfaceLoc, bindingName, /*set=*/0, /*binding=*/binding++, - IREE::HAL::DescriptorType::StorageBuffer, - IREE::HAL::MemoryAccessBitfield::Read); - } else if (inputType.value().isa()) { - auto bindingName = "arg" + std::to_string(argOrdinal++); - interfaceBuilder.create( - interfaceLoc, bindingName, /*set=*/0, /*binding=*/binding++, + interfaceLoc, bindingName, /*set=*/0, /*binding=*/bindingOrdinal, IREE::HAL::DescriptorType::StorageBuffer, IREE::HAL::MemoryAccessBitfield::Read); - } else if (inputType.value().isa()) { - auto bindingName = "ret" + std::to_string(retOrdinal++); + } else if (auto tensorType = + inputType.value() + .dyn_cast()) { + StringRef prefix; + IREE::HAL::MemoryAccessBitfield memoryAccess = + IREE::HAL::MemoryAccessBitfield::None; + switch (tensorType.getAccess()) { + case IREE::Flow::TensorAccess::ReadOnly: + prefix = "ro"; + memoryAccess = IREE::HAL::MemoryAccessBitfield::Read; + break; + case IREE::Flow::TensorAccess::ReadWrite: + prefix = "rw"; + memoryAccess = IREE::HAL::MemoryAccessBitfield::Read | + IREE::HAL::MemoryAccessBitfield::Write; + break; + case IREE::Flow::TensorAccess::WriteOnly: + prefix = "wo"; + memoryAccess = IREE::HAL::MemoryAccessBitfield::DiscardWrite; + break; + } + int bindingOrdinal = nextBindingOrdinal++; + std::string bindingName = + std::string(prefix) + std::to_string(bindingOrdinal); interfaceBuilder.create( - interfaceLoc, bindingName, /*set=*/0, /*binding=*/binding++, - IREE::HAL::DescriptorType::StorageBuffer, - IREE::HAL::MemoryAccessBitfield::DiscardWrite); + interfaceLoc, bindingName, /*set=*/0, /*binding=*/bindingOrdinal, + IREE::HAL::DescriptorType::StorageBuffer, memoryAccess); } else if (auto indexType = inputType.value().dyn_cast()) { ++pushConstantCount; } else if (auto integerType = inputType.value().dyn_cast()) { @@ -113,10 +127,11 @@ static llvm::Optional declareInterfaceIO( } } for (auto outputType : llvm::enumerate(anyFuncOp.getType().getResults())) { + int bindingOrdinal = nextBindingOrdinal++; auto bindingName = "ret" + std::to_string(outputType.index()); if (outputType.value().isa()) { interfaceBuilder.create( - interfaceLoc, bindingName, /*set=*/0, /*binding=*/binding++, + interfaceLoc, bindingName, /*set=*/0, /*binding=*/bindingOrdinal, IREE::HAL::DescriptorType::StorageBuffer, IREE::HAL::MemoryAccessBitfield::DiscardWrite); } else { diff --git a/iree/compiler/Dialect/HAL/Transforms/Passes.cpp b/iree/compiler/Dialect/HAL/Transforms/Passes.cpp index 29f64a461a0f..a0dc69f0ce26 100644 --- a/iree/compiler/Dialect/HAL/Transforms/Passes.cpp +++ b/iree/compiler/Dialect/HAL/Transforms/Passes.cpp @@ -137,6 +137,9 @@ void buildHALTransformPassPipeline(OpPassManager &passManager, // if we serialized things. passManager.addPass(createSymbolDCEPass()); } + + // Final cleanup of IR; cleans up things left behind by CSE/DCE above. + passManager.addNestedPass(createCanonicalizerPass()); } void buildHALTransformPassPipeline(OpPassManager &passManager, diff --git a/iree/compiler/Dialect/HAL/Transforms/test/materialize_interfaces.mlir b/iree/compiler/Dialect/HAL/Transforms/test/materialize_interfaces.mlir index 1f98f6e700a6..b90afde7bafc 100644 --- a/iree/compiler/Dialect/HAL/Transforms/test/materialize_interfaces.mlir +++ b/iree/compiler/Dialect/HAL/Transforms/test/materialize_interfaces.mlir @@ -112,15 +112,15 @@ flow.executable @shaped_dispatch { // CHECK-LABEL: hal.executable @static_tiled_dispatch // CHECK-NEXT: hal.interface @legacy_io { -// CHECK-NEXT: hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read" -// CHECK-NEXT: hal.interface.binding @ret0, set=0, binding=1, type="StorageBuffer", access="Write|Discard" +// CHECK-NEXT: hal.interface.binding @ro0, set=0, binding=0, type="StorageBuffer", access="Read" +// CHECK-NEXT: hal.interface.binding @wo1, set=0, binding=1, type="StorageBuffer", access="Write|Discard" // CHECK-NEXT: } flow.executable @static_tiled_dispatch { // CHECK-NEXT: hal.executable.target @vmla, filter="vmla" { // CHECK-NEXT: hal.executable.entry_point @entry attributes { // CHECK-SAME: interface = @legacy_io, // CHECK-SAME: ordinal = 0 : i32, - // CHECK-SAME: signature = (!flow.dispatch.input<8x4xf32>, !flow.dispatch.output<4x8xf32>) -> () + // CHECK-SAME: signature = (!flow.dispatch.tensor, !flow.dispatch.tensor) -> () // CHECK-SAME: } flow.dispatch.entry @entry attributes { signature = (tensor<8x4xf32>) -> tensor<4x8xf32>, @@ -129,17 +129,17 @@ flow.executable @static_tiled_dispatch { // CHECK-NEXT: module { module { // CHECK-NEXT: func @entry() { - func @entry(%arg: !flow.dispatch.input<8x4xf32>, %ret: !flow.dispatch.output<4x8xf32>) { + func @entry(%arg: !flow.dispatch.tensor, %ret: !flow.dispatch.tensor) { // CHECK-NEXT: %c0 = constant 0 : index - // CHECK-NEXT: %[[ARG:.+]] = hal.interface.binding.subspan @legacy_io::@arg0[%c0] : !flow.dispatch.input<8x4xf32> - // CHECK-NEXT: %[[RET:.+]] = hal.interface.binding.subspan @legacy_io::@ret0[%c0] : !flow.dispatch.output<4x8xf32> + // CHECK-NEXT: %[[ARG:.+]] = hal.interface.binding.subspan @legacy_io::@ro0[%c0] : !flow.dispatch.tensor + // CHECK-NEXT: %[[RET:.+]] = hal.interface.binding.subspan @legacy_io::@wo1[%c0] : !flow.dispatch.tensor - // CHECK-NEXT: %[[ARG_TILE:.+]] = flow.dispatch.input.load %[[ARG]] - %arg_tile = flow.dispatch.input.load %arg : !flow.dispatch.input<8x4xf32> -> tensor<8x4xf32> + // CHECK-NEXT: %[[ARG_TILE:.+]] = flow.dispatch.tensor.load %[[ARG]] + %arg_tile = flow.dispatch.tensor.load %arg : !flow.dispatch.tensor -> tensor<8x4xf32> // CHECK-NEXT: %[[RET_TILE:.+]] = "test.sink"(%[[ARG_TILE]]) %ret_tile = "test.sink"(%arg_tile) : (tensor<8x4xf32>) -> tensor<4x8xf32> - // CHECK-NEXT: flow.dispatch.output.store %[[RET_TILE]], %[[RET]] - flow.dispatch.output.store %ret_tile, %ret : tensor<4x8xf32> -> !flow.dispatch.output<4x8xf32> + // CHECK-NEXT: flow.dispatch.tensor.store %[[RET_TILE]], %[[RET]] + flow.dispatch.tensor.store %ret_tile, %ret : tensor<4x8xf32> -> !flow.dispatch.tensor return } } @@ -149,15 +149,15 @@ flow.executable @static_tiled_dispatch { // CHECK-LABEL: hal.executable @dynamic_tiled_dispatch // CHECK-NEXT: hal.interface @legacy_io attributes {push_constants = 4 : i32} { -// CHECK-NEXT: hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read" -// CHECK-NEXT: hal.interface.binding @ret0, set=0, binding=1, type="StorageBuffer", access="Write|Discard" +// CHECK-NEXT: hal.interface.binding @ro0, set=0, binding=0, type="StorageBuffer", access="Read" +// CHECK-NEXT: hal.interface.binding @wo1, set=0, binding=1, type="StorageBuffer", access="Write|Discard" // CHECK-NEXT: } flow.executable @dynamic_tiled_dispatch { // CHECK-NEXT: hal.executable.target @vmla, filter="vmla" { // CHECK-NEXT: hal.executable.entry_point @entry attributes { // CHECK-SAME: interface = @legacy_io, // CHECK-SAME: ordinal = 0 : i32, - // CHECK-SAME: signature = (!flow.dispatch.input<7x?x24x?xf32>, !flow.dispatch.output, index, index, index, index) -> () + // CHECK-SAME: signature = (!flow.dispatch.tensor, !flow.dispatch.tensor, index, index, index, index) -> () // CHECK-SAME: } flow.dispatch.entry @entry attributes { signature = (tensor<7x?x24x?xf32>) -> tensor, @@ -168,10 +168,10 @@ flow.executable @dynamic_tiled_dispatch { // CHECK-NEXT: func @entry() { func @entry( // CHECK-NEXT: %c0 = constant 0 : index - // CHECK-DAG: %[[ARG:.+]] = hal.interface.binding.subspan @legacy_io::@arg0[%c0] : !flow.dispatch.input<7x?x24x?xf32> - %arg: !flow.dispatch.input<7x?x24x?xf32>, - // CHECK-DAG: %[[RET:.+]] = hal.interface.binding.subspan @legacy_io::@ret0[%c0] : !flow.dispatch.output - %ret: !flow.dispatch.output, + // CHECK-DAG: %[[ARG:.+]] = hal.interface.binding.subspan @legacy_io::@ro0[%c0] : !flow.dispatch.tensor + %arg: !flow.dispatch.tensor, + // CHECK-DAG: %[[RET:.+]] = hal.interface.binding.subspan @legacy_io::@wo1[%c0] : !flow.dispatch.tensor + %ret: !flow.dispatch.tensor, // CHECK-DAG: %[[ARG_DIM1:.+]] = hal.interface.load.constant offset = 0 : index %arg_dim1: index, // CHECK-DAG: %[[ARG_DIM3:.+]] = hal.interface.load.constant offset = 1 : index @@ -184,17 +184,17 @@ flow.executable @dynamic_tiled_dispatch { // CHECK-NEXT: %[[ARG_SHAPE:.+]] = shapex.make_ranked_shape %[[ARG_DIM1]], %[[ARG_DIM3]] %arg_shape = shapex.make_ranked_shape %arg_dim1, %arg_dim3 : (index, index) -> !shapex.ranked_shape<[7,?,24,?]> // CHECK-NEXT: %[[ARG_SHAPED:.+]] = flow.dispatch.tie_shape %[[ARG]], %[[ARG_SHAPE]] - %arg_shaped = flow.dispatch.tie_shape %arg, %arg_shape : (!flow.dispatch.input<7x?x24x?xf32>, !shapex.ranked_shape<[7,?,24,?]>) -> !flow.dispatch.input<7x?x24x?xf32> + %arg_shaped = flow.dispatch.tie_shape %arg, %arg_shape : (!flow.dispatch.tensor, !shapex.ranked_shape<[7,?,24,?]>) -> !flow.dispatch.tensor // CHECK-NEXT: %[[RET_SHAPE:.+]] = shapex.make_ranked_shape %[[RET_DIM0]], %[[RET_DIM1]] %ret_shape = shapex.make_ranked_shape %ret_dim0, %ret_dim1 : (index, index) -> !shapex.ranked_shape<[?,?,1024]> // CHECK-NEXT: %[[RET_SHAPED:.+]] = flow.dispatch.tie_shape %[[RET]], %[[RET_SHAPE]] - %ret_shaped = flow.dispatch.tie_shape %ret, %ret_shape : (!flow.dispatch.output, !shapex.ranked_shape<[?,?,1024]>) -> !flow.dispatch.output - // CHECK-NEXT: %[[ARG_TILE:.+]] = flow.dispatch.input.load %[[ARG_SHAPED]] - %arg_tile = flow.dispatch.input.load %arg_shaped : !flow.dispatch.input<7x?x24x?xf32> -> tensor<7x?x24x?xf32> + %ret_shaped = flow.dispatch.tie_shape %ret, %ret_shape : (!flow.dispatch.tensor, !shapex.ranked_shape<[?,?,1024]>) -> !flow.dispatch.tensor + // CHECK-NEXT: %[[ARG_TILE:.+]] = flow.dispatch.tensor.load %[[ARG_SHAPED]] + %arg_tile = flow.dispatch.tensor.load %arg_shaped : !flow.dispatch.tensor -> tensor<7x?x24x?xf32> // CHECK-NEXT: %[[RET_TILE:.+]] = "test.tile_math"(%[[ARG_TILE]]) %ret_tile = "test.tile_math"(%arg_tile) : (tensor<7x?x24x?xf32>) -> tensor - // CHECK-NEXT: flow.dispatch.output.store %[[RET_TILE]], %[[RET_SHAPED]] - flow.dispatch.output.store %ret_tile, %ret_shaped : tensor -> !flow.dispatch.output + // CHECK-NEXT: flow.dispatch.tensor.store %[[RET_TILE]], %[[RET_SHAPED]] + flow.dispatch.tensor.store %ret_tile, %ret_shaped : tensor -> !flow.dispatch.tensor return } } @@ -204,15 +204,15 @@ flow.executable @dynamic_tiled_dispatch { // CHECK-LABEL: hal.executable @workgroup_infos // CHECK-NEXT: hal.interface @legacy_io { -// CHECK-NEXT: hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read" -// CHECK-NEXT: hal.interface.binding @ret0, set=0, binding=1, type="StorageBuffer", access="Write|Discard" +// CHECK-NEXT: hal.interface.binding @ro0, set=0, binding=0, type="StorageBuffer", access="Read" +// CHECK-NEXT: hal.interface.binding @wo1, set=0, binding=1, type="StorageBuffer", access="Write|Discard" // CHECK-NEXT: } flow.executable @workgroup_infos { // CHECK-NEXT: hal.executable.target @vmla, filter="vmla" { // CHECK-NEXT: hal.executable.entry_point @entry attributes { // CHECK-SAME: interface = @legacy_io, // CHECK-SAME: ordinal = 0 : i32, - // CHECK-SAME: signature = (!flow.dispatch.input<8x4xf32>, !flow.dispatch.output<4x8xf32>) -> () + // CHECK-SAME: signature = (!flow.dispatch.tensor, !flow.dispatch.tensor) -> () // CHECK-SAME: } flow.dispatch.entry @entry attributes { signature = (tensor<8x4xf32>) -> tensor<4x8xf32>, @@ -221,7 +221,7 @@ flow.executable @workgroup_infos { // CHECK-NEXT: module { module { // CHECK-NEXT: func @entry() { - func @entry(%arg: !flow.dispatch.input<8x4xf32>, %ret: !flow.dispatch.output<4x8xf32>) { + func @entry(%arg: !flow.dispatch.tensor, %ret: !flow.dispatch.tensor) { // CHECK-DAG: %[[WORKGROUP_ID_X:.+]] = hal.interface.workgroup.id[0] : index %id_x = flow.dispatch.workgroup.id[0] : index // CHECK-DAG: %[[WORKGROUP_ID_Y:.+]] = hal.interface.workgroup.id[1] : index diff --git a/iree/compiler/Dialect/IREE/IR/BUILD b/iree/compiler/Dialect/IREE/IR/BUILD index d9da85ae50b4..455f4613bfae 100644 --- a/iree/compiler/Dialect/IREE/IR/BUILD +++ b/iree/compiler/Dialect/IREE/IR/BUILD @@ -29,6 +29,7 @@ filegroup( srcs = enforce_glob( [ "IREEBase.td", + "IREEInterfaces.td", "IREEOps.td", ], include = ["*.td"], @@ -39,18 +40,21 @@ cc_library( name = "IR", srcs = [ "IREEDialect.cpp", + "IREEOpInterfaces.cpp.inc", "IREEOps.cpp", "IREEOps.cpp.inc", "IREETypes.cpp", ], hdrs = [ "IREEDialect.h", + "IREEOpInterfaces.h.inc", "IREEOps.h", "IREEOps.h.inc", "IREETraits.h", "IREETypes.h", ], deps = [ + ":IREEInterfacesGen", ":IREEOpsGen", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", @@ -61,6 +65,20 @@ cc_library( ], ) +gentbl( + name = "IREEInterfacesGen", + tbl_outs = [ + ("-gen-op-interface-decls", "IREEOpInterfaces.h.inc"), + ("-gen-op-interface-defs", "IREEOpInterfaces.cpp.inc"), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "IREEInterfaces.td", + td_srcs = [ + ":td_files", + "@llvm-project//mlir:OpBaseTdFiles", + ], +) + gentbl( name = "IREEOpsGen", tbl_outs = [ diff --git a/iree/compiler/Dialect/IREE/IR/CMakeLists.txt b/iree/compiler/Dialect/IREE/IR/CMakeLists.txt index 1758d1cd6254..772034ba6054 100644 --- a/iree/compiler/Dialect/IREE/IR/CMakeLists.txt +++ b/iree/compiler/Dialect/IREE/IR/CMakeLists.txt @@ -15,12 +15,14 @@ iree_cc_library( IR HDRS "IREEDialect.h" + "IREEOpInterfaces.h.inc" "IREEOps.h" "IREEOps.h.inc" "IREETraits.h" "IREETypes.h" SRCS "IREEDialect.cpp" + "IREEOpInterfaces.cpp.inc" "IREEOps.cpp" "IREEOps.cpp.inc" "IREETypes.cpp" @@ -34,6 +36,16 @@ iree_cc_library( PUBLIC ) +iree_tablegen_library( + NAME + IREEInterfacesGen + TD_FILE + "IREEInterfaces.td" + OUTS + -gen-op-interface-decls IREEOpInterfaces.h.inc + -gen-op-interface-defs IREEOpInterfaces.cpp.inc +) + iree_tablegen_library( NAME IREEOpsGen diff --git a/iree/compiler/Dialect/IREE/IR/IREEBase.td b/iree/compiler/Dialect/IREE/IR/IREEBase.td index 0d2b1bbba123..c4906dab4fa4 100644 --- a/iree/compiler/Dialect/IREE/IR/IREEBase.td +++ b/iree/compiler/Dialect/IREE/IR/IREEBase.td @@ -53,6 +53,11 @@ class IREE_IndexAttrBase : } def IREE_IndexAttr : IREE_IndexAttrBase<"size_t">; +def IREE_TiedOpStorageAttr : + TypedArrayAttrBase { + let constBuilderCall = "$_builder.getI64ArrayAttr($0)"; +} + //===----------------------------------------------------------------------===// // Status codes //===----------------------------------------------------------------------===// diff --git a/iree/compiler/Dialect/IREE/IR/IREEInterfaces.td b/iree/compiler/Dialect/IREE/IR/IREEInterfaces.td new file mode 100644 index 000000000000..2e2a65e3df79 --- /dev/null +++ b/iree/compiler/Dialect/IREE/IR/IREEInterfaces.td @@ -0,0 +1,166 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_DIALECT_IREE_INTERFACES +#define IREE_DIALECT_IREE_INTERFACES + +include "iree/compiler/Dialect/IREE/IR/IREEBase.td" + +//===----------------------------------------------------------------------===// +// IREE::TiedOpInterface +//===----------------------------------------------------------------------===// + +def IREE_TiedOpInterface : OpInterface<"TiedOpInterface"> { + let description = [{ + An operation that "ties" one or more results to its operands indicating + that the result is directly related to the operand in an operation-defined + way. Results are still SSA values distinct from the operands and the tie is + strictly a relationship relevant to transformations and not something that + modifies IR definitions. + + Example: + An operation on tensors that wants to indicate that the storage for a + result should alias the storage for an operand, performing an "in-place" + operation. Since tensors are still used there is no hard requirement that + uses of the result SSA value alias the operand; a copy may still be + introduced. + + See: flow.dispatch.workgroups + + Example: + An operation on buffers that wants to encode activity on the buffer in IR + (such as a barrier, a transfer operation, etc) such that the SSA use-def + chain is representing the state of the buffer at various points in time + but that the underlying buffers are all tied together. + + See: hal.stream.barrier + + The default implementations use an attribute on the op to store the + relationship: + `OptionalAttr:$tied_operands` + }]; + + let methods = [ + InterfaceMethod< + /*desc=*/[{ + Returns the set of operands that results may be tied to as an + (index, length) pair ala getODSOperandIndexAndLength. + + By default assumes all operands may be tied. If an op treats some + operands as special then the op can override this and specify only the + ones it will tie. For example, a cond_branch that has a condition + operand as well as the successor operands would return only the range + of successor operands. + }], + /*retTy=*/"std::pair", + /*methodName=*/"getTiedOperandsIndexAndLength", (ins), + /*args=*/[{}], + /*defaultImplementation=*/[{ + return {0, $_op.getNumOperands()}; + }] + >, + InterfaceMethod< + /*desc=*/[{ + Walks up the SSA use-def chain to find the first defined value reachable + from the given value by traversing tied ops. The returned value may be + in another block if that block dominates the one the result is defined + in. + + Note that the returned value may be a block argument and have no + defining op, and the search will not continue past branches. + If the result is untied then the result itself is returned. + }], + /*retTy=*/"Value", + /*methodName=*/"getTiedResult", + /*args=*/(ins "unsigned":$resultIndex), + /*methodBody=*/[{}], + /*defaultImplementation=*/[{ + return IREE::TiedOpInterface::findTiedBaseValue($_op.getResult(resultIndex)); + }] + >, + InterfaceMethod< + /*desc=*/[{ + Returns the operand tied to the given result of the op or nullptr if + none. + }], + /*retTy=*/"Value", + /*methodName=*/"getTiedResultOperand", + /*args=*/(ins "Value":$result), + /*methodBody=*/[{}], + /*defaultImplementation=*/[{ + auto resultIndex = result.cast().getResultNumber(); + auto operandIndex = IREE::detail::getTiedResultOperandIndex($_op, resultIndex); + return operandIndex.hasValue() ? + $_op.getOperand(operandIndex.getValue()) : + nullptr; + }] + >, + InterfaceMethod< + /*desc=*/[{ + Returns the operand index tied to the given result index, if any. + }], + /*retTy=*/"::llvm::Optional", + /*methodName=*/"getTiedResultOperandIndex", + /*args=*/(ins "unsigned":$resultIndex), + /*methodBody=*/[{}], + /*defaultImplementation=*/[{ + return IREE::detail::getTiedResultOperandIndex($_op, resultIndex); + }] + >, + InterfaceMethod< + /*desc=*/[{ + Sets the operand index tied to the given result index, if any. + }], + /*retTy=*/"void", + /*methodName=*/"setTiedResultOperandIndex", + /*args=*/(ins "unsigned":$resultIndex, + "::llvm::Optional":$operandIndex), + /*methodBody=*/[{}], + /*defaultImplementation=*/[{ + return IREE::detail::setTiedResultOperandIndex($_op, resultIndex, operandIndex); + }] + >, + InterfaceMethod< + /*desc=*/[{ + Returns an array containing the tied result operand indices with -1 + indicating that a result is not tied. + }], + /*retTy=*/"SmallVector", + /*methodName=*/"getTiedResultOperandIndices", + /*args=*/(ins), + /*methodBody=*/[{}], + /*defaultImplementation=*/[{ + return IREE::detail::getTiedResultOperandIndices($_op); + }] + >, + ]; + + let extraClassDeclaration = [{ + static StringRef getStorageAttrName() { return "tied_operands"; } + + // Indicates that a result is not tied to any operand. + static constexpr int64_t kUntiedIndex = -1; + + // Walks the SSA use-def chain to find the first defined value reachable + // from the given value by traversing tied ops. Note that the returned + // value may be a block argument and have no defining op. + static Value findTiedBaseValue(Value derivedValue); + }]; + + let verify = [{ + return IREE::detail::verifyTiedOp($_op); + }]; +} + +#endif // IREE_DIALECT_IREE_INTERFACES diff --git a/iree/compiler/Dialect/IREE/IR/IREETypes.cpp b/iree/compiler/Dialect/IREE/IR/IREETypes.cpp index 745697cac353..c278c5370eb6 100644 --- a/iree/compiler/Dialect/IREE/IR/IREETypes.cpp +++ b/iree/compiler/Dialect/IREE/IR/IREETypes.cpp @@ -17,6 +17,7 @@ #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Diagnostics.h" #include "mlir/IR/TypeSupport.h" +#include "mlir/Interfaces/CastInterfaces.h" namespace mlir { namespace iree_compiler { @@ -103,6 +104,124 @@ PtrType PtrType::getChecked(function_ref emitError, Type PtrType::getTargetType() { return getImpl()->targetType; } +//===----------------------------------------------------------------------===// +// TiedOpInterface +//===----------------------------------------------------------------------===// + +llvm::Optional detail::getTiedResultOperandIndex( + Operation *op, unsigned resultIndex) { + auto storageAttr = + op->getAttrOfType(TiedOpInterface::getStorageAttrName()); + if (!storageAttr) return llvm::None; + auto valueAttrs = storageAttr.getValue(); + if (valueAttrs.empty()) return llvm::None; + int64_t value = valueAttrs[resultIndex].cast().getInt(); + if (value == TiedOpInterface::kUntiedIndex) return llvm::None; + auto tiedOp = cast(op); + unsigned tiedOperandsOffset = tiedOp.getTiedOperandsIndexAndLength().first; + return tiedOperandsOffset + static_cast(value); +} + +void detail::setTiedResultOperandIndex(Operation *op, unsigned resultIndex, + llvm::Optional operandIndex) { + auto indices = getTiedResultOperandIndices(op); + if (indices.empty()) { + indices.resize(op->getNumResults(), TiedOpInterface::kUntiedIndex); + } + indices[resultIndex] = operandIndex.hasValue() + ? operandIndex.getValue() + : TiedOpInterface::kUntiedIndex; + auto indexType = IndexType::get(op->getContext()); + op->setAttr(TiedOpInterface::getStorageAttrName(), + ArrayAttr::get(op->getContext(), + llvm::to_vector<8>(llvm::map_range( + indices, [&](int64_t v) -> Attribute { + return IntegerAttr::get(indexType, v); + })))); +} + +SmallVector detail::getTiedResultOperandIndices(Operation *op) { + SmallVector indices; + auto storageAttr = + op->getAttrOfType(TiedOpInterface::getStorageAttrName()); + if (!storageAttr) return indices; + auto valueAttrs = storageAttr.getValue(); + if (valueAttrs.empty()) return indices; + auto tiedOp = cast(op); + unsigned tiedOperandsOffset = tiedOp.getTiedOperandsIndexAndLength().first; + indices.resize(op->getNumResults()); + for (unsigned i = 0; i < valueAttrs.size(); ++i) { + int64_t index = valueAttrs[i].cast().getInt(); + indices[i] = index != TiedOpInterface::kUntiedIndex + ? tiedOperandsOffset + index + : TiedOpInterface::kUntiedIndex; + } + return indices; +} + +Value TiedOpInterface::findTiedBaseValue(Value derivedValue) { + Value baseValue = derivedValue; + while (auto definingOp = + dyn_cast_or_null(baseValue.getDefiningOp())) { + auto tiedValue = definingOp.getTiedResultOperand(baseValue); + if (!tiedValue) break; + baseValue = tiedValue; + } + return baseValue; +} + +LogicalResult detail::verifyTiedOp(TiedOpInterface tiedOp) { + unsigned tiedOperandsOffset = tiedOp.getTiedOperandsIndexAndLength().first; + auto storageAttr = + tiedOp->getAttrOfType(TiedOpInterface::getStorageAttrName()); + if (!storageAttr || storageAttr.getValue().empty()) { + return success(); + } + auto tiedOperandIndices = storageAttr.getValue(); + if (tiedOperandIndices.size() != tiedOp->getNumResults()) { + return tiedOp.emitError("op results/tied operand indices mismatch"); + } + for (unsigned resultIndex = 0; resultIndex < tiedOp->getNumResults(); + ++resultIndex) { + int64_t tiedOperandIndex = + tiedOperandIndices[resultIndex].cast().getInt(); + if (tiedOperandIndex < 0) continue; + auto operandType = + tiedOp->getOperand(tiedOperandsOffset + tiedOperandIndex).getType(); + auto resultType = tiedOp->getResult(resultIndex).getType(); + if (operandType != resultType) { + return tiedOp.emitError( + "tied operand and result type mismatch; operand has ") + << operandType << " and result has " << resultType; + } + } + return success(); +} + +void excludeTiedOperandAndResultIndices( + ArrayRef excludedOperandIndices, + ArrayRef excludedResultIndices, + SmallVector &tiedOperandIndices) { + SmallVector oldTiedOperandIndices = tiedOperandIndices; + tiedOperandIndices.clear(); + for (auto it : llvm::enumerate(oldTiedOperandIndices)) { + unsigned resultIndex = it.index(); + if (!llvm::count(excludedResultIndices, resultIndex)) { + continue; // result removed + } + int64_t tiedOperandIndex = it.value(); + if (tiedOperandIndex != TiedOpInterface::kUntiedIndex) { + if (!llvm::count(excludedOperandIndices, tiedOperandIndex)) { + tiedOperandIndex = TiedOpInterface::kUntiedIndex; // operand removed + } + } + tiedOperandIndices.push_back(tiedOperandIndex); + } +} + +// At the end so it can use functions above: +#include "iree/compiler/Dialect/IREE/IR/IREEOpInterfaces.cpp.inc" + } // namespace IREE } // namespace iree_compiler } // namespace mlir diff --git a/iree/compiler/Dialect/IREE/IR/IREETypes.h b/iree/compiler/Dialect/IREE/IR/IREETypes.h index 3b55df45480b..73e81305bdec 100644 --- a/iree/compiler/Dialect/IREE/IR/IREETypes.h +++ b/iree/compiler/Dialect/IREE/IR/IREETypes.h @@ -18,6 +18,7 @@ #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Diagnostics.h" #include "mlir/IR/Location.h" +#include "mlir/IR/OpDefinition.h" #include "mlir/IR/TypeSupport.h" #include "mlir/IR/Types.h" @@ -25,10 +26,14 @@ namespace mlir { namespace iree_compiler { namespace IREE { +class TiedOpInterface; + namespace detail { + struct ListTypeStorage; struct PtrTypeStorage; struct RankedShapeTypeStorage; + } // namespace detail // Status code table mapping to iree::StatusCode in the runtime. @@ -111,6 +116,24 @@ class MutableByteBufferType using Base::Base; }; +namespace detail { +llvm::Optional getTiedResultOperandIndex(Operation *op, + unsigned resultIndex); +void setTiedResultOperandIndex(Operation *op, unsigned resultIndex, + llvm::Optional operandIndex); +SmallVector getTiedResultOperandIndices(Operation *op); +LogicalResult verifyTiedOp(TiedOpInterface tiedOp); +} // namespace detail + +// Resets or removes the indices in |tiedOperandIndices| based on the given +// exclusion lists. +void excludeTiedOperandAndResultIndices( + ArrayRef excludedOperandIndices, + ArrayRef excludedResultIndices, + SmallVector &tiedOperandIndices); + +#include "iree/compiler/Dialect/IREE/IR/IREEOpInterfaces.h.inc" + } // namespace IREE } // namespace iree_compiler diff --git a/iree/test/e2e/models/unidirectional_lstm.mlir b/iree/test/e2e/models/unidirectional_lstm.mlir index 4ad0c7935b9e..e72556852eea 100644 --- a/iree/test/e2e/models/unidirectional_lstm.mlir +++ b/iree/test/e2e/models/unidirectional_lstm.mlir @@ -1,8 +1,8 @@ // An example LSTM exported from a python reference model with dummy weights. -// RUN: iree-run-mlir %s -iree-hal-target-backends=vmla -function-input="1x5xf32=[0 1 0 3 4]" -function-input="1x5x2x2xf32=[1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20]" | IreeFileCheck %s --implicit-check-not="[" --implicit-check-not="]" -// RUN: [[ $IREE_LLVMAOT_DISABLE == 1 ]] || (iree-run-mlir %s -iree-hal-target-backends=dylib-llvm-aot -function-input="1x5xf32=[0 1 0 3 4]" -function-input="1x5x2x2xf32=[1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20]" | IreeFileCheck %s --implicit-check-not="[" --implicit-check-not="]") -// RUN: [[ $IREE_VULKAN_DISABLE == 1 ]] || (iree-run-mlir %s -iree-hal-target-backends=vulkan-spirv -function-input="1x5xf32=[0 1 0 3 4]" -function-input="1x5x2x2xf32=[1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20]" | IreeFileCheck %s --implicit-check-not="[" --implicit-check-not="]") +// RUN: iree-run-mlir %s -iree-hal-target-backends=vmla -function-input="1x5xf32=[0,1,0,3,4]" -function-input="1x5x2x2xf32=[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20]" | IreeFileCheck %s --implicit-check-not="[" --implicit-check-not="]" +// RUN: [[ $IREE_LLVMAOT_DISABLE == 1 ]] || (iree-run-mlir %s -iree-hal-target-backends=dylib-llvm-aot -function-input="1x5xf32=[0,1,0,3,4]" -function-input="1x5x2x2xf32=[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20]" | IreeFileCheck %s --implicit-check-not="[" --implicit-check-not="]") +// RUN: [[ $IREE_VULKAN_DISABLE == 1 ]] || (iree-run-mlir %s -iree-hal-target-backends=vulkan-spirv -function-input="1x5xf32=[0,1,0,3,4]" -function-input="1x5x2x2xf32=[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20]" | IreeFileCheck %s --implicit-check-not="[" --implicit-check-not="]") // Exported via the XLA HLO Importer // The resulting MLIR was modified by hand by changing all large constants to be From 603e9fb0ae7e6a8e0d666e92667fea4d16870db0 Mon Sep 17 00:00:00 2001 From: "Ahmed S. Taei" Date: Thu, 11 Mar 2021 17:59:13 -0800 Subject: [PATCH 16/17] Use upstream math dialect polynomial approximation pass (#5068) - Note the approximations don't have finite math assumption, so they are on by default. --- iree/compiler/Conversion/LLVMToLLVM/BUILD | 36 ----- .../Conversion/LLVMToLLVM/CMakeLists.txt | 29 ---- .../LLVMToLLVM/FastExpConversion.cpp | 125 ------------------ iree/compiler/Conversion/LLVMToLLVM/Passes.h | 30 ----- iree/compiler/Conversion/LinalgToLLVM/BUILD | 1 - .../Conversion/LinalgToLLVM/CMakeLists.txt | 1 - .../Conversion/LinalgToLLVM/ConvertToLLVM.cpp | 7 + .../Conversion/LinalgToLLVM/Passes.cpp | 12 -- iree/test/e2e/llvm_specific/BUILD | 12 -- iree/test/e2e/llvm_specific/CMakeLists.txt | 13 -- iree/test/e2e/llvm_specific/exponential.mlir | 27 ---- 11 files changed, 7 insertions(+), 286 deletions(-) delete mode 100644 iree/compiler/Conversion/LLVMToLLVM/BUILD delete mode 100644 iree/compiler/Conversion/LLVMToLLVM/CMakeLists.txt delete mode 100644 iree/compiler/Conversion/LLVMToLLVM/FastExpConversion.cpp delete mode 100644 iree/compiler/Conversion/LLVMToLLVM/Passes.h delete mode 100644 iree/test/e2e/llvm_specific/exponential.mlir diff --git a/iree/compiler/Conversion/LLVMToLLVM/BUILD b/iree/compiler/Conversion/LLVMToLLVM/BUILD deleted file mode 100644 index 165664750db2..000000000000 --- a/iree/compiler/Conversion/LLVMToLLVM/BUILD +++ /dev/null @@ -1,36 +0,0 @@ -# Copyright 2020 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -package( - default_visibility = ["//visibility:public"], - features = ["layering_check"], - licenses = ["notice"], # Apache 2.0 -) - -cc_library( - name = "LLVMToLLVM", - srcs = [ - "FastExpConversion.cpp", - ], - hdrs = [ - "Passes.h", - ], - deps = [ - "@llvm-project//llvm:Support", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:LLVMDialect", - "@llvm-project//mlir:Pass", - "@llvm-project//mlir:Transforms", - ], -) diff --git a/iree/compiler/Conversion/LLVMToLLVM/CMakeLists.txt b/iree/compiler/Conversion/LLVMToLLVM/CMakeLists.txt deleted file mode 100644 index 583a8fbb3f90..000000000000 --- a/iree/compiler/Conversion/LLVMToLLVM/CMakeLists.txt +++ /dev/null @@ -1,29 +0,0 @@ -################################################################################ -# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from # -# iree/compiler/Conversion/LLVMToLLVM/BUILD # -# # -# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary # -# CMake-only content. # -# # -# To disable autogeneration for this file entirely, delete this header. # -################################################################################ - -iree_add_all_subdirs() - -iree_cc_library( - NAME - LLVMToLLVM - HDRS - "Passes.h" - SRCS - "FastExpConversion.cpp" - DEPS - LLVMSupport - MLIRIR - MLIRLLVMIR - MLIRPass - MLIRTransforms - PUBLIC -) - -### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ### diff --git a/iree/compiler/Conversion/LLVMToLLVM/FastExpConversion.cpp b/iree/compiler/Conversion/LLVMToLLVM/FastExpConversion.cpp deleted file mode 100644 index df8451793afd..000000000000 --- a/iree/compiler/Conversion/LLVMToLLVM/FastExpConversion.cpp +++ /dev/null @@ -1,125 +0,0 @@ -// Copyright 2020 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "mlir/Dialect/LLVMIR/LLVMDialect.h" -#include "mlir/Dialect/LLVMIR/LLVMTypes.h" -#include "mlir/IR/Builders.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" - -namespace mlir { -namespace iree_compiler { -namespace { - -// Fast polynomial approximation of exp(x) using its reduced range exp(y) -// where y is in the range [0, ln(2)], let y = x - floor(x / ln(2)) * ln(2) -// = x - k * ln(2), exp(x) = exp(y) * 2^k. exp(y) is computed with 4th degree -// polyomial: exp(y) = c0 + c1 * y + c2 * y^2 + c3 * y^3 + c4 * y^4 -struct FastExpConversionPattern : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(LLVM::ExpOp op, - PatternRewriter &rewriter) const override { - constexpr float ln2Const = 0.693147181f; - constexpr float ln2InvConst = 1.44269504f; - - // Least squares polynomial fit computed : - // cValues = np.polyfit(np.linspace(0, math.log(2), 10000), np.exp(x), 4) - constexpr float cValues[5] = {0.05924867f, 0.15514645f, 0.50308552f, - 0.99968939f, 1.00000721531f}; - auto loc = op.getLoc(); - Value x = op.getOperand(); - - auto floatType = Float32Type::get(rewriter.getContext()); - auto i32Type = IntegerType::get(rewriter.getContext(), 32); - - Value ln2 = rewriter.create( - loc, floatType, rewriter.getF32FloatAttr(ln2Const)); - Value ln2Inv = rewriter.create( - loc, floatType, rewriter.getF32FloatAttr(ln2InvConst)); - - // Compute reduced range input y = x - floor(x / ln(2)) * ln(2) - Value xL2Inv = rewriter.create(loc, floatType, x, ln2Inv); - Value kF32 = rewriter.create(loc, floatType, xL2Inv); - Value kLn2 = rewriter.create(loc, floatType, kF32, ln2); - Value y = rewriter.create(loc, floatType, x, kLn2); - - SmallVector PConst(5); - for (int i = 0; i < 5; ++i) { - PConst[i] = rewriter.create( - loc, floatType, rewriter.getF32FloatAttr(cValues[i])); - } - // Evaluate exp(y) = sum(c[i] * y**i, i) - Value expY = rewriter.create(loc, floatType, y, PConst[0]); - expY = rewriter.create(loc, floatType, expY, PConst[1]); - expY = rewriter.create(loc, floatType, expY, y); - expY = rewriter.create(loc, floatType, expY, PConst[2]); - expY = rewriter.create(loc, floatType, expY, y); - expY = rewriter.create(loc, floatType, expY, PConst[3]); - expY = rewriter.create(loc, floatType, expY, y); - expY = rewriter.create(loc, floatType, expY, PConst[4]); - - // Compute exp2(k) with integer bitshift: - // exp2(k) = f32_bitcast((127 + k) << 23) - Value fPBias = rewriter.create( - loc, i32Type, rewriter.getI32IntegerAttr(127)); - Value k = rewriter.create(loc, i32Type, kF32); - Value kPlusfPBias = rewriter.create(loc, i32Type, k, fPBias); - Value shiftConst = rewriter.create( - loc, i32Type, rewriter.getI32IntegerAttr(23)); - Value twoPowkI = - rewriter.create(loc, i32Type, kPlusfPBias, shiftConst); - Value twoPowk = rewriter.create(loc, floatType, twoPowkI); - expY = rewriter.create(loc, floatType, expY, twoPowk); - rewriter.replaceOp(op, {expY}); - // TODO(ataei): Handle overflow and underflow cases (e.g |k| > 128). - return success(); - } -}; - -struct FastExpConversionPass - : public PassWrapper> { - void getDependentDialects(DialectRegistry ®istry) const override { - registry.insert(); - } - void runOnOperation() override; -}; - -} // namespace - -void populateFastExpConversionPatterns(OwningRewritePatternList &patterns, - MLIRContext *context) { - patterns.insert(context); -} - -void FastExpConversionPass::runOnOperation() { - auto moduleOp = getOperation(); - auto context = moduleOp.getContext(); - OwningRewritePatternList patterns; - populateFastExpConversionPatterns(patterns, context); - (void)applyPatternsAndFoldGreedily(moduleOp, std::move(patterns)); -} - -std::unique_ptr> -createFastExpApproximationConversionPass() { - return std::make_unique(); -} - -static PassRegistration> pass( - "iree-codegen-linalg-to-llvm-fast-exp-conversion-pass", - "Convert llvm.intr.exp into its fast polynomial approximation version", - [] { return std::make_unique(); }); - -} // namespace iree_compiler -} // namespace mlir diff --git a/iree/compiler/Conversion/LLVMToLLVM/Passes.h b/iree/compiler/Conversion/LLVMToLLVM/Passes.h deleted file mode 100644 index e40fe90842b8..000000000000 --- a/iree/compiler/Conversion/LLVMToLLVM/Passes.h +++ /dev/null @@ -1,30 +0,0 @@ -// Copyright 2020 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_COMPILER_CONVERSION_LLVMTOLLVM_PASSES_H_ -#define IREE_COMPILER_CONVERSION_LLVMTOLLVM_PASSES_H_ - -#include "mlir/Pass/Pass.h" - -namespace mlir { -namespace iree_compiler { - -// Creates a pass to rewrite llvm.intr.exp using its reduced range polynomial -// approximation. -std::unique_ptr> -createFastExpApproximationConversionPass(); - -} // namespace iree_compiler -} // namespace mlir -#endif // IREE_COMPILER_CONVERSION_LLVMTOLLVM_PASSES_H_ diff --git a/iree/compiler/Conversion/LinalgToLLVM/BUILD b/iree/compiler/Conversion/LinalgToLLVM/BUILD index ce8cd2b22d31..cbbeb4588e18 100644 --- a/iree/compiler/Conversion/LinalgToLLVM/BUILD +++ b/iree/compiler/Conversion/LinalgToLLVM/BUILD @@ -40,7 +40,6 @@ cc_library( "//iree/compiler/Conversion/Common", "//iree/compiler/Conversion/HLOToHLO", "//iree/compiler/Conversion/HLOToLinalg", - "//iree/compiler/Conversion/LLVMToLLVM", "//iree/compiler/Dialect/HAL/IR", "//iree/compiler/Dialect/HAL/IR:HALDialect", "//iree/compiler/Dialect/IREE/IR", diff --git a/iree/compiler/Conversion/LinalgToLLVM/CMakeLists.txt b/iree/compiler/Conversion/LinalgToLLVM/CMakeLists.txt index 32bb64d19482..cc3cccd89083 100644 --- a/iree/compiler/Conversion/LinalgToLLVM/CMakeLists.txt +++ b/iree/compiler/Conversion/LinalgToLLVM/CMakeLists.txt @@ -51,7 +51,6 @@ iree_cc_library( iree::compiler::Conversion::Common iree::compiler::Conversion::HLOToHLO iree::compiler::Conversion::HLOToLinalg - iree::compiler::Conversion::LLVMToLLVM iree::compiler::Dialect::HAL::IR iree::compiler::Dialect::HAL::IR::HALDialect iree::compiler::Dialect::IREE::IR diff --git a/iree/compiler/Conversion/LinalgToLLVM/ConvertToLLVM.cpp b/iree/compiler/Conversion/LinalgToLLVM/ConvertToLLVM.cpp index 140a7e320547..37e1bc0a695d 100644 --- a/iree/compiler/Conversion/LinalgToLLVM/ConvertToLLVM.cpp +++ b/iree/compiler/Conversion/LinalgToLLVM/ConvertToLLVM.cpp @@ -642,6 +642,13 @@ void ConvertToLLVMPass::runOnOperation() { std::move(vectorToLoopsPatterns)); } + // math dialect elementry functions -> polynomial form. + { + OwningRewritePatternList mathPatterns; + populateMathPolynomialApproximationPatterns(mathPatterns, &getContext()); + (void)applyPatternsAndFoldGreedily(getOperation(), std::move(mathPatterns)); + } + auto module = getOperation(); LLVMTypeConverter converter(&getContext()); diff --git a/iree/compiler/Conversion/LinalgToLLVM/Passes.cpp b/iree/compiler/Conversion/LinalgToLLVM/Passes.cpp index c080d5190ddf..2fa6794b5dbc 100644 --- a/iree/compiler/Conversion/LinalgToLLVM/Passes.cpp +++ b/iree/compiler/Conversion/LinalgToLLVM/Passes.cpp @@ -17,7 +17,6 @@ #include "iree/compiler/Conversion/Common/Attributes.h" #include "iree/compiler/Conversion/Common/Passes.h" #include "iree/compiler/Conversion/HLOToHLO/Passes.h" -#include "iree/compiler/Conversion/LLVMToLLVM/Passes.h" #include "iree/compiler/Conversion/LinalgToLLVM/Passes.h" #include "iree/compiler/Dialect/Shape/Transforms/Passes.h" #include "mlir/Conversion/SCFToStandard/SCFToStandard.h" @@ -40,12 +39,6 @@ static llvm::cl::opt convImg2ColConversion( "linag.matmul"), llvm::cl::init(false)); -static llvm::cl::opt fastExpConversion( - "iree-codegen-linalg-to-llvm-fast-exp", - llvm::cl::desc("If true convert llvm.intr.exp into its range reduced " - "polynomial approximation."), - llvm::cl::init(false)); - void addLinalgToLLVMPasses(OpPassManager &passManager) { // Distribute linalg op among a 3d grid of parallel threads. Tile each // workgroup thread memory then vectorize the linalg op. @@ -85,11 +78,6 @@ void addLinalgToLLVMPasses(OpPassManager &passManager) { nestedModulePM.addPass(createCanonicalizerPass()); nestedModulePM.addPass(createCSEPass()); - - // Approximate llvm.intr.exp with a 4-th order ploynmial in range[0, ln2]. - if (fastExpConversion) { - nestedModulePM.addPass(createFastExpApproximationConversionPass()); - } } void buildLLVMTransformPassPipeline(OpPassManager &passManager) { diff --git a/iree/test/e2e/llvm_specific/BUILD b/iree/test/e2e/llvm_specific/BUILD index dffd9477264c..9c7d633670ab 100644 --- a/iree/test/e2e/llvm_specific/BUILD +++ b/iree/test/e2e/llvm_specific/BUILD @@ -34,15 +34,3 @@ iree_check_single_backend_test_suite( driver = "dylib", target_backend = "dylib-llvm-aot", ) - -iree_check_single_backend_test_suite( - name = "check_llvm-aot-exponential_fast", - srcs = [ - "exponential.mlir", - ], - compiler_flags = [ - "-iree-codegen-linalg-to-llvm-fast-exp=true", - ], - driver = "dylib", - target_backend = "dylib-llvm-aot", -) diff --git a/iree/test/e2e/llvm_specific/CMakeLists.txt b/iree/test/e2e/llvm_specific/CMakeLists.txt index 66272613afab..b7f1add319bc 100644 --- a/iree/test/e2e/llvm_specific/CMakeLists.txt +++ b/iree/test/e2e/llvm_specific/CMakeLists.txt @@ -23,17 +23,4 @@ iree_check_single_backend_test_suite( "-iree-codegen-linalg-to-llvm-conv-img2col-conversion=true" ) -iree_check_single_backend_test_suite( - NAME - check_llvm-aot-exponential_fast - SRCS - "exponential.mlir" - TARGET_BACKEND - "dylib-llvm-aot" - DRIVER - "dylib" - COMPILER_FLAGS - "-iree-codegen-linalg-to-llvm-fast-exp=true" -) - ### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ### diff --git a/iree/test/e2e/llvm_specific/exponential.mlir b/iree/test/e2e/llvm_specific/exponential.mlir deleted file mode 100644 index 6e9132698a41..000000000000 --- a/iree/test/e2e/llvm_specific/exponential.mlir +++ /dev/null @@ -1,27 +0,0 @@ -func @tensor() attributes { iree.module.export } { - %input = iree.unfoldable_constant dense<[0.0, 1.0, 2.0, 4.0]> : tensor<4xf32> - %result = "mhlo.exponential"(%input) : (tensor<4xf32>) -> tensor<4xf32> - check.expect_almost_eq_const(%result, dense<[1.0, 2.7183, 7.3891, 54.5981]> : tensor<4xf32>) : tensor<4xf32> - return -} - -func @scalar() attributes { iree.module.export } { - %input = iree.unfoldable_constant dense<1.0> : tensor - %result = "mhlo.exponential"(%input) : (tensor) -> tensor - check.expect_almost_eq_const(%result, dense<2.7183> : tensor) : tensor - return -} - -func @double() attributes { iree.module.export } { - %input = iree.unfoldable_constant dense<1.0> : tensor - %result = "mhlo.exponential"(%input) : (tensor) -> tensor - check.expect_almost_eq_const(%result, dense<2.7183> : tensor) : tensor - return -} - -func @negative() attributes { iree.module.export } { - %input = iree.unfoldable_constant dense<-1.0> : tensor - %result = "mhlo.exponential"(%input) : (tensor) -> tensor - check.expect_almost_eq_const(%result, dense<0.367879> : tensor) : tensor - return -} From e1136e30a788d2eb14271017d8202e8491616cc5 Mon Sep 17 00:00:00 2001 From: Thomas Date: Thu, 11 Mar 2021 18:49:45 -0800 Subject: [PATCH 17/17] Fix use after free detected by asan (#5080) --- .../ConcretizeTileAmongWorkgroupsPass.cpp | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/iree/compiler/Conversion/LinalgToSPIRV/ConcretizeTileAmongWorkgroupsPass.cpp b/iree/compiler/Conversion/LinalgToSPIRV/ConcretizeTileAmongWorkgroupsPass.cpp index 6973241ddae7..5a68092a010f 100644 --- a/iree/compiler/Conversion/LinalgToSPIRV/ConcretizeTileAmongWorkgroupsPass.cpp +++ b/iree/compiler/Conversion/LinalgToSPIRV/ConcretizeTileAmongWorkgroupsPass.cpp @@ -148,7 +148,8 @@ LogicalResult getInputOutputTypesForAllTiles( /// /// TODO(antiagainst): This pass can be shared between CPU and GPU. But the /// following query scopes it to GPU for now. -llvm::Optional, ArrayRef>> +llvm::Optional< + std::pair, llvm::SmallVector>> getTileSizeAndWorkgroupSize(Operation *rootOp, ArrayRef inputTypes, ArrayRef outputTypes) { // Build necesary structures to query the tile sizes for distributing to @@ -188,7 +189,8 @@ getTileSizeAndWorkgroupSize(Operation *rootOp, ArrayRef inputTypes, // preparation. launchConfig->finalize(rootOp->getParentOfType()); - return std::make_pair(tileSize, workgroupSize); + return std::make_pair(llvm::to_vector<4>(tileSize), + llvm::to_vector<4>(workgroupSize)); } /// Replaces hal.interface.workgroup.size op with the constant value chosen @@ -438,9 +440,10 @@ class ConcretizeTileAmongWorkgroupsPass // The tile sizes are specified against the original dimension order of // the workload shape. But Flow/HAL processor id/size/count ops' are // created using the reverse order. - tileSize = llvm::to_vector<4>( - llvm::reverse(sizes->first.take_front(numTilableDims))); - workgroupSize = llvm::to_vector<4>(sizes->second); + tileSize = sizes->first; + tileSize.resize(numTilableDims); + tileSize = llvm::to_vector<4>(llvm::reverse(tileSize)); + workgroupSize = sizes->second; } else { return funcOp.emitError("failed to query tile size and workgroup size"); }