diff --git a/compiler/src/iree/compiler/ConstEval/BUILD.bazel b/compiler/src/iree/compiler/ConstEval/BUILD.bazel index 2751d4855dc4..d7aab29d53fe 100644 --- a/compiler/src/iree/compiler/ConstEval/BUILD.bazel +++ b/compiler/src/iree/compiler/ConstEval/BUILD.bazel @@ -53,7 +53,9 @@ iree_compiler_cc_library( ":PassHeaders", ":PassesIncGen", ":Runtime", + "//compiler/src/iree/compiler/Dialect/Flow/IR", "//compiler/src/iree/compiler/Dialect/HAL/Target", + "//compiler/src/iree/compiler/Dialect/Stream/IR", "//compiler/src/iree/compiler/Dialect/Util/Analysis/Constant", "//compiler/src/iree/compiler/Dialect/Util/IR", "//compiler/src/iree/compiler/Pipelines", @@ -63,6 +65,7 @@ iree_compiler_cc_library( "@llvm-project//mlir:FunctionInterfaces", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", + "@llvm-project//mlir:TransformUtils", ], ) diff --git a/compiler/src/iree/compiler/ConstEval/CMakeLists.txt b/compiler/src/iree/compiler/ConstEval/CMakeLists.txt index 2e4701646ae3..43b69bb02aed 100644 --- a/compiler/src/iree/compiler/ConstEval/CMakeLists.txt +++ b/compiler/src/iree/compiler/ConstEval/CMakeLists.txt @@ -49,7 +49,10 @@ iree_cc_library( MLIRFunctionInterfaces MLIRIR MLIRPass + MLIRTransformUtils + iree::compiler::Dialect::Flow::IR iree::compiler::Dialect::HAL::Target + iree::compiler::Dialect::Stream::IR iree::compiler::Dialect::Util::Analysis::Constant iree::compiler::Dialect::Util::IR iree::compiler::Pipelines diff --git a/compiler/src/iree/compiler/ConstEval/JitGlobals.cpp b/compiler/src/iree/compiler/ConstEval/JitGlobals.cpp index 9a4d89a65818..84ba1c43fc1f 100644 --- a/compiler/src/iree/compiler/ConstEval/JitGlobals.cpp +++ b/compiler/src/iree/compiler/ConstEval/JitGlobals.cpp @@ -6,7 +6,9 @@ #include "iree/compiler/ConstEval/Passes.h" #include "iree/compiler/ConstEval/Runtime.h" +#include "iree/compiler/Dialect/Flow/IR/FlowOps.h" #include "iree/compiler/Dialect/HAL/Target/TargetOptions.h" +#include "iree/compiler/Dialect/Stream/IR/StreamTypes.h" #include 
"iree/compiler/Dialect/Util/Analysis/Constant/ConstExpr.h" #include "iree/compiler/Dialect/Util/Analysis/Constant/OpOracle.h" #include "iree/compiler/Dialect/Util/IR/UtilOps.h" @@ -21,7 +23,9 @@ #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/IRMapping.h" +#include "mlir/IR/PatternMatch.h" #include "mlir/IR/SymbolTable.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include @@ -448,6 +452,63 @@ static LogicalResult cloneUsedObjects(FunctionOpInterface funcOp, return success(); } +// Erases flow.tensor.transfer ops, forwarding their operand to all users. +struct StripFlowTensorTransferPattern + : public OpRewritePattern<IREE::Flow::TensorTransferOp> { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(IREE::Flow::TensorTransferOp op, + PatternRewriter &rewriter) const override { + rewriter.replaceAllUsesWith(op.getResult(), op.getOperand()); + rewriter.eraseOp(op); + return success(); + } +}; + +static void stripStreamAffinityAttributes(IREE::Util::FuncOp funcOp) { + // Strips affinity attributes from funcOp and all ops nested within it. + // Only entries of dictionary attributes (which covers ops' named attribute + // dictionaries) are dropped; affinities inside an ArrayAttr are left alone, + // since removing an element there could silently change an op's semantics. + AttrTypeReplacer replacer; + SmallVector<NamedAttribute> namedAttrs; + replacer.addReplacement( + [&](Attribute originalAttr) -> std::optional<Attribute> { + if (auto dictAttr = dyn_cast<DictionaryAttr>(originalAttr)) { + namedAttrs.clear(); + for (auto &namedAttr : dictAttr) { + if (!isa<IREE::Stream::AffinityAttr>(namedAttr.getValue())) { + namedAttrs.push_back(namedAttr); + } + } + if (namedAttrs.size() == dictAttr.size()) { + // No affinity entries in this dictionary: keep the attribute as-is. + return std::nullopt; + } else { + return DictionaryAttr::getWithSorted(dictAttr.getContext(), + namedAttrs); + } + } + return std::nullopt; + }); + replacer.recursivelyReplaceElementsIn(funcOp); +} + +// Strips device/queue affinities from the IR rooted at `funcOp`: +// erases `flow.tensor.transfer` ops and drops stream affinity attributes.
+static LogicalResult +stripExecutionContextAffinities(IREE::Util::FuncOp funcOp) { + RewritePatternSet patterns(funcOp.getContext()); + patterns.add<StripFlowTensorTransferPattern>(funcOp.getContext()); + if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) { + return emitError(funcOp->getLoc()) + << "Stripping execution context affinities failed"; + } + + stripStreamAffinityAttributes(funcOp); + return success(); +} + class ProgramBuilder { public: ProgramBuilder(ModuleOp sourceModuleOp, @@ -495,6 +556,7 @@ class ProgramBuilder { funcOp.erase(); return failure(); } + return success(); } @@ -598,6 +660,10 @@ class ProgramBuilder { termBuilder.create(funcOp.getLoc(), returns); funcOp.setType(termBuilder.getFunctionType(argumentTypes, returnTypes)); + if (failed(stripExecutionContextAffinities(funcOp))) { + return failure(); + } + jitFunctions.push_back(std::move(desc)); return success(); } diff --git a/compiler/src/iree/compiler/ConstEval/test/jit_globals.mlir b/compiler/src/iree/compiler/ConstEval/test/jit_globals.mlir index fa5dd913411f..629091b6b278 100644 --- a/compiler/src/iree/compiler/ConstEval/test/jit_globals.mlir +++ b/compiler/src/iree/compiler/ConstEval/test/jit_globals.mlir @@ -439,3 +439,43 @@ module @dispatch_executable { util.return } } + +// ----- + +// CHECK-LABEL: @strip_flow_tensor_transfer +// CHECK: util.global private @[[EVALED:.+]] = dense<2.000000e+02> : tensor<2xf16> +module @strip_flow_tensor_transfer { + util.global private @hoisted : tensor<2xf16> + // CHECK-NOT: util.initializer + util.initializer { + %cst = arith.constant dense<2.0e+2> : tensor<2xf16> + %cst_transferred = flow.tensor.transfer %cst : tensor<2xf16> to #hal.device.promise<@dev_a> + util.global.store %cst_transferred, @hoisted : tensor<2xf16> + util.return + } + util.func public @main() -> tensor<2xf16> { + // CHECK: util.global.load @[[EVALED]] + %hoisted = util.global.load @hoisted : tensor<2xf16> + util.return %hoisted : tensor<2xf16> + } +} + +// ----- + +// CHECK-LABEL:
@strip_stream_affinity_attribute +// CHECK: util.global private @[[EVALED:.+]] = dense<1> : tensor<2xi32> +module @strip_stream_affinity_attribute { + util.global private @hoisted : tensor<2xi32> + // CHECK-NOT: util.initializer + util.initializer { + %c1 = arith.constant 1 : i32 + %tensor = flow.tensor.splat %c1 : tensor<2xi32> attributes { stream.affinity = #hal.device.promise<@dev_a> } + util.global.store %tensor, @hoisted : tensor<2xi32> + util.return + } + util.func public @main() -> tensor<2xi32> { + // CHECK: util.global.load @[[EVALED]] + %hoisted = util.global.load @hoisted : tensor<2xi32> + util.return %hoisted : tensor<2xi32> + } +}