Skip to content

Commit

Permalink
Remove Usage of Deprecated FunctionPass (llvm#1134)
Browse files Browse the repository at this point in the history
* remove usage of deprected FunctionPass

Signed-off-by: ian Bearman <[email protected]>
Co-authored-by: Alexandre Eichenberger <[email protected]>
  • Loading branch information
manbearian and AlexandreEichenberger authored Feb 1, 2022
1 parent af02912 commit c09cf01
Show file tree
Hide file tree
Showing 13 changed files with 48 additions and 41 deletions.
16 changes: 11 additions & 5 deletions src/Conversion/KrnlToAffine/KrnlToAffine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -432,13 +432,13 @@ void lowerGetInductionVariableValueOp(KrnlGetInductionVariableValueOp &getIVOp,
/// At this stage the dialect will contain standard operations as well like
/// add and multiply, this pass will leave these operations intact.
struct ConvertKrnlToAffinePass
: public PassWrapper<ConvertKrnlToAffinePass, FunctionPass> {
: public PassWrapper<ConvertKrnlToAffinePass, OperationPass<FuncOp>> {

StringRef getArgument() const override { return "convert-krnl-to-affine"; }

StringRef getDescription() const override { return "Lower Krnl dialect."; }

void runOnFunction() final;
void runOnOperation() final;
};

LogicalResult interpretOperation(Operation *op, OpBuilder &builder,
Expand Down Expand Up @@ -1443,9 +1443,15 @@ void markLoopBodyAsMovable(
}
}

void ConvertKrnlToAffinePass::runOnFunction() {
void ConvertKrnlToAffinePass::runOnOperation() {
OpBuilder builder(&getContext());
FuncOp funcOp = getFunction();
FuncOp funcOp = getOperation();

// external function: nothing to do
if (funcOp.body().empty()) {
return;
}

// Move invariant instructions outside of the loops as many as possible. This
// helps make loops perfectly nested, which facilitates transformations.
funcOp.walk([&](KrnlIterateOp loopOp) {
Expand Down Expand Up @@ -1530,7 +1536,7 @@ void ConvertKrnlToAffinePass::runOnFunction() {

DenseSet<Operation *> unconverted;
if (failed(applyPartialConversion(
getFunction(), target, std::move(patterns), &unconverted))) {
getOperation(), target, std::move(patterns), &unconverted))) {
{
const std::lock_guard<std::mutex> lock(unrollAndJamMutex);
unrollAndJamMap.erase(currFuncOp);
Expand Down
6 changes: 3 additions & 3 deletions src/Transform/BundleMemoryPools.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -506,7 +506,7 @@ class KrnlMoveConstantsUp : public OpRewritePattern<arith::ConstantOp> {
*/

class KrnlBundleMemoryPoolsPass
: public PassWrapper<KrnlBundleMemoryPoolsPass, FunctionPass> {
: public PassWrapper<KrnlBundleMemoryPoolsPass, OperationPass<FuncOp>> {

BlockToMemPool blockToStaticPool;
BlockToMemPool blockToDynamicPool;
Expand All @@ -518,8 +518,8 @@ class KrnlBundleMemoryPoolsPass
return "Bundle memory pools of internal MemRefs into a single memory pool.";
}

void runOnFunction() override {
auto function = getFunction();
void runOnOperation() override {
auto function = getOperation();

ConversionTarget target(getContext());
RewritePatternSet patterns(&getContext());
Expand Down
7 changes: 4 additions & 3 deletions src/Transform/DisconnectKrnlDimFromAlloc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -129,16 +129,17 @@ class DisconnectKrnlDimFromAlloc : public OpRewritePattern<KrnlDimOp> {
* Function pass that disconnects krnl.dim emission from its MemRef alloc.
*/
class DisconnectKrnlDimFromAllocPass
: public PassWrapper<DisconnectKrnlDimFromAllocPass, FunctionPass> {
: public PassWrapper<DisconnectKrnlDimFromAllocPass,
OperationPass<FuncOp>> {
public:
StringRef getArgument() const override { return "lower-krnl-shape-to-std"; }

StringRef getDescription() const override {
return "Lowers krnl shape-related operations.";
}

void runOnFunction() override {
auto function = getFunction();
void runOnOperation() override {
auto function = getOperation();

ConversionTarget target(getContext());
RewritePatternSet patterns(&getContext());
Expand Down
6 changes: 3 additions & 3 deletions src/Transform/ElideKrnlGlobalConstants.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,16 +80,16 @@ namespace {
* Function pass that performs constant value elision of Krnl globals.
*/
class ElideConstGlobalValuePass
: public PassWrapper<ElideConstGlobalValuePass, FunctionPass> {
: public PassWrapper<ElideConstGlobalValuePass, OperationPass<FuncOp>> {
public:
StringRef getArgument() const override { return "elide-krnl-constants"; }

StringRef getDescription() const override {
return "Elide the constant values of the Global Krnl operations.";
}

void runOnFunction() override {
auto function = getFunction();
void runOnOperation() override {
auto function = getOperation();

ConversionTarget target(getContext());
RewritePatternSet patterns(&getContext());
Expand Down
6 changes: 3 additions & 3 deletions src/Transform/EnableMemoryPool.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -185,16 +185,16 @@ class KrnlEliminateOldDealloc : public OpRewritePattern<memref::DeallocOp> {
* Function pass that enables memory pooling for MemRefs.
*/
class KrnlEnableMemoryPoolPass
: public PassWrapper<KrnlEnableMemoryPoolPass, FunctionPass> {
: public PassWrapper<KrnlEnableMemoryPoolPass, OperationPass<FuncOp>> {
public:
StringRef getArgument() const override { return "enable-memory-pool"; }

StringRef getDescription() const override {
return "Enable a memory pool for allocating internal MemRefs.";
}

void runOnFunction() override {
auto function = getFunction();
void runOnOperation() override {
auto function = getOperation();

ConversionTarget target(getContext());
RewritePatternSet patterns(&getContext());
Expand Down
6 changes: 3 additions & 3 deletions src/Transform/LowerKrnlShape.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -78,16 +78,16 @@ class LowerKrnlShape : public OpRewritePattern<KrnlShapeOp> {
* Function pass that emits the shape of a MemRef.
*/
class LowerKrnlShapePass
: public PassWrapper<LowerKrnlShapePass, FunctionPass> {
: public PassWrapper<LowerKrnlShapePass, OperationPass<FuncOp>> {
public:
StringRef getArgument() const override { return "lower-krnl-shape"; }

StringRef getDescription() const override {
return "Lower krnl.shape operation to use Shape dialect operations.";
}

void runOnFunction() override {
auto function = getFunction();
void runOnOperation() override {
auto function = getOperation();

ConversionTarget target(getContext());
RewritePatternSet patterns(&getContext());
Expand Down
8 changes: 4 additions & 4 deletions src/Transform/ONNX/ConstProp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -745,7 +745,7 @@ class ConstPropScatterNDPattern : public OpRewritePattern<ONNXScatterNDOp> {
//===----------------------------------------------------------------------===//

struct ConstPropONNXToONNXPass
: public PassWrapper<ConstPropONNXToONNXPass, FunctionPass> {
: public PassWrapper<ConstPropONNXToONNXPass, OperationPass<FuncOp>> {

StringRef getArgument() const override { return "constprop-onnx"; }

Expand All @@ -754,12 +754,12 @@ struct ConstPropONNXToONNXPass
"other ONNX operations.";
}

void runOnFunction() final;
void runOnOperation() final;
};
} // end anonymous namespace.

void ConstPropONNXToONNXPass::runOnFunction() {
auto function = getFunction();
void ConstPropONNXToONNXPass::runOnOperation() {
auto function = getOperation();
MLIRContext *context = &getContext();

ConversionTarget target(getContext());
Expand Down
8 changes: 4 additions & 4 deletions src/Transform/ONNX/Decompose.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ Value createSequenceConstructOp(
namespace {

struct DecomposeONNXToONNXPass
: public PassWrapper<DecomposeONNXToONNXPass, FunctionPass> {
: public PassWrapper<DecomposeONNXToONNXPass, OperationPass<FuncOp>> {

StringRef getArgument() const override { return "decompose-onnx"; }

Expand All @@ -132,12 +132,12 @@ struct DecomposeONNXToONNXPass
"operations.";
}

void runOnFunction() final;
void runOnOperation() final;
};
} // end anonymous namespace.

void DecomposeONNXToONNXPass::runOnFunction() {
auto function = getFunction();
void DecomposeONNXToONNXPass::runOnOperation() {
auto function = getOperation();
MLIRContext *context = &getContext();

ConversionTarget target(getContext());
Expand Down
6 changes: 3 additions & 3 deletions src/Transform/ONNX/ElideConstants.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,16 +64,16 @@ class ConstantValueElision : public OpRewritePattern<ONNXConstantOp> {
* Function pass that performs constant value elision.
*/
class ElideConstantValuePass
: public PassWrapper<ElideConstantValuePass, FunctionPass> {
: public PassWrapper<ElideConstantValuePass, OperationPass<FuncOp>> {
public:
StringRef getArgument() const override { return "elide-constants"; }

StringRef getDescription() const override {
return "Elide values of constant operations.";
}

void runOnFunction() override {
auto function = getFunction();
void runOnOperation() override {
auto function = getOperation();

ConversionTarget target(getContext());
RewritePatternSet patterns(&getContext());
Expand Down
6 changes: 3 additions & 3 deletions src/Transform/ONNX/InstrumentONNXPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ llvm::cl::bits<InstrumentActions> InstrumentControlBits(
llvm::cl::cat(OMPassOptions));

class InstrumentONNXPass
: public mlir::PassWrapper<InstrumentONNXPass, FunctionPass> {
: public mlir::PassWrapper<InstrumentONNXPass, OperationPass<FuncOp>> {

private:
bool allOpsAllowed;
Expand All @@ -84,13 +84,13 @@ class InstrumentONNXPass
runtimeActions = InstrumentControlBits.getBits();
};

void runOnFunction() override {
void runOnOperation() override {
if (instrumentONNXOps == "" || instrumentONNXOps == "NONE")
return;
init(instrumentONNXOps);

// Iterate on the operations nested in this function
getFunction().walk([&](mlir::Operation *op) {
getOperation().walk([&](mlir::Operation *op) {
if (isa<mlir::ONNXOpsDialect>(op->getDialect())) {
// Skip the prefix "onnx." of onnx op name
const char *opName = op->getName().getStringRef().data() + 5;
Expand Down
6 changes: 3 additions & 3 deletions src/Transform/ONNX/ONNXPreKrnlVerifyPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,15 +38,15 @@ namespace {
*/

class ONNXPreKrnlVerifyPass
: public mlir::PassWrapper<ONNXPreKrnlVerifyPass, FunctionPass> {
: public mlir::PassWrapper<ONNXPreKrnlVerifyPass, OperationPass<FuncOp>> {

public:
StringRef getArgument() const override { return "onnx-pre-krnl-verify"; }

StringRef getDescription() const override { return "Verify onnx ops."; }

void runOnFunction() override {
auto function = getFunction();
void runOnOperation() override {
auto function = getOperation();
auto &funcBody = function.getBody();

// Iterate on the operations
Expand Down
2 changes: 1 addition & 1 deletion src/Transform/ONNX/ShapeInferencePass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ static SmallVector<mlir::FuncOp, 4> lookUpFuncsMatching(
}

/*!
* FunctionPass that performs shape inference by iterating over a list of
* Function pass that performs shape inference by iterating over a list of
* candidate operations and propagating the shape information until the list
* of operations is empty [credit MLIR authors].
*
Expand Down
6 changes: 3 additions & 3 deletions src/Transform/OptimizeMemoryPools.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -868,7 +868,7 @@ class KrnlCompactStaticMemoryPools : public OpRewritePattern<memref::AllocOp> {
* Function pass that optimizes memory pools.
*/
class KrnlOptimizeMemoryPoolsPass
: public PassWrapper<KrnlOptimizeMemoryPoolsPass, FunctionPass> {
: public PassWrapper<KrnlOptimizeMemoryPoolsPass, OperationPass<FuncOp>> {
BlockToCompactedAlignments blockToStaticPoolAlignments;
BlockToDiscardedGetRefs blockToDiscardedGetRefs;

Expand All @@ -879,8 +879,8 @@ class KrnlOptimizeMemoryPoolsPass
return "Optimize the static and dynamic memory pools.";
}

void runOnFunction() override {
auto function = getFunction();
void runOnOperation() override {
auto function = getOperation();

ConversionTarget target(getContext());
RewritePatternSet patterns(&getContext());
Expand Down

0 comments on commit c09cf01

Please sign in to comment.