Skip to content

Commit

Permalink
Added Arith to StableHLO pass as an option and move to seperate file.
Browse files Browse the repository at this point in the history
  • Loading branch information
uazizTT committed Oct 20, 2024
1 parent fc69a9f commit 4fa4e66
Show file tree
Hide file tree
Showing 9 changed files with 124 additions and 68 deletions.
19 changes: 19 additions & 0 deletions include/ttmlir/Conversion/ArithToStableHLO/ArithToStableHLO.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#ifndef TTMLIR_CONVERSION_ARITHTOSTABLEHLO_ARITHTOSTABLEHLO_H
#define TTMLIR_CONVERSION_ARITHTOSTABLEHLO_ARITHTOSTABLEHLO_H

#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"

namespace mlir::tt {

#ifdef TTMLIR_ENABLE_STABLEHLO
std::unique_ptr<OperationPass<ModuleOp>> createConvertArithToStableHLOPass();
#endif

} // namespace mlir::tt

#endif // TTMLIR_CONVERSION_STABLEHLOTOTTIR_STABLEHLOTOTTIR_H
1 change: 1 addition & 0 deletions include/ttmlir/Conversion/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#define TTMLIR_CONVERSION_PASSES_H

#ifdef TTMLIR_ENABLE_STABLEHLO
#include "ttmlir/Conversion/ArithToStableHLO/ArithToStableHLO.h"
#include "ttmlir/Conversion/StableHLOToTTIR/StableHLOToTTIR.h"
#endif
#include "ttmlir/Conversion/TTIRToTTMetal/TTIRToTTMetal.h"
Expand Down
5 changes: 0 additions & 5 deletions include/ttmlir/Conversion/StableHLOToTTIR/StableHLOToTTIR.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,6 @@ void populateStableHLOToTTIRPatterns(MLIRContext *ctx,
TypeConverter &typeConverter);

std::unique_ptr<OperationPass<ModuleOp>> createConvertStableHLOToTTIRPass();

void populateArithToStableHLOPatterns(MLIRContext *ctx,
RewritePatternSet &patterns,
TypeConverter &typeConverter);
std::unique_ptr<OperationPass<ModuleOp>> createConvertArithToStableHLOPass();
#endif

} // namespace mlir::tt
Expand Down
7 changes: 7 additions & 0 deletions include/ttmlir/Dialect/TTIR/Pipelines/TTIRPipelines.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,13 @@ struct StableHLOToTTIRPipelineOptions
// Currently this pass fails if module has a name, so keeping the
// optimization OFF by default until that issue is fixed on llvm side.
llvm::cl::init(false)};
Option<bool> arithDialectConversionsEnabled{
*this, "enable-arith-to-stablehlo",
llvm::cl::desc("Enable Arith to StableHLO conversion pass."),
// Currently torch-mlir front-end does not convert ConstantOp for Arith
// Dialect to StableHLO. This pass makes those conversions until this
// is fixed in the upstream torch-mlir.
llvm::cl::init(true)};
};

void createStableHLOToTTIRPipeline(
Expand Down
93 changes: 93 additions & 0 deletions lib/Conversion/StableHLOToTTIR/ArithToStableHLOPass.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#include "ttmlir/Conversion/ArithToStableHLO/ArithToStableHLO.h"

#include <mlir/Dialect/Arith/IR/Arith.h>
#include <mlir/Dialect/Func/IR/FuncOps.h>
#include <mlir/Dialect/Func/Transforms/FuncConversions.h>
#include <mlir/Dialect/Tensor/IR/Tensor.h>
#include <mlir/IR/BuiltinOps.h>
#include <mlir/IR/Dialect.h>
#include <mlir/IR/PatternMatch.h>
#include <mlir/Pass/Pass.h>

#include <stablehlo/dialect/StablehloOps.h>

#include "ttmlir/Dialect/TT/IR/TT.h"
#include "ttmlir/Dialect/TTIR/IR/TTIR.h"

using namespace mlir;
using namespace mlir::tt;

namespace mlir::tt::ttir {

#define GEN_PASS_DEF_CONVERTARITHTOSTABLEHLO
#include "ttmlir/Conversion/Passes.h.inc"

} // namespace mlir::tt::ttir

namespace {

class ArithToStableHLOConstantOpConversionPattern
: public OpConversionPattern<mlir::arith::ConstantOp> {

using OpConversionPattern<mlir::arith::ConstantOp>::OpConversionPattern;

public:
LogicalResult
matchAndRewrite(mlir::arith::ConstantOp srcOp,
mlir::arith::ConstantOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {

rewriter.replaceOpWithNewOp<mlir::stablehlo::ConstantOp>(srcOp,
srcOp.getValue());
return success();
}
};

struct ConvertArithToStableHLOPass
: public ttir::impl::ConvertArithToStableHLOBase<
ConvertArithToStableHLOPass> {
void runOnOperation() final {
mlir::ConversionTarget target(getContext());

target.addIllegalDialect<mlir::arith::ArithDialect>();
target.addLegalDialect<mlir::stablehlo::StablehloDialect>();
target.addLegalOp<mlir::ModuleOp>();
target.addLegalOp<mlir::func::FuncOp>();
target.addLegalOp<mlir::func::ReturnOp>();

// For now keep the same type assuming StableHLO ops operate on builtin
// tensor.
TypeConverter typeConverter;
typeConverter.addConversion([](Type type) {
assert(isa<RankedTensorType>(type) &&
"only ranked tensor type supported");
return type;
});
RewritePatternSet patterns(&getContext());

// Convert Arith ConstantOp to StableHLO ConstantOp
patterns.add<ArithToStableHLOConstantOpConversionPattern>(typeConverter,
&getContext());

// Apply conversion.
if (failed(
applyFullConversion(getOperation(), target, std::move(patterns)))) {
signalPassFailure();
return;
}
}
};

} // namespace

namespace mlir::tt {

std::unique_ptr<OperationPass<ModuleOp>> createConvertArithToStableHLOPass() {
return std::make_unique<ConvertArithToStableHLOPass>();
}

} // namespace mlir::tt
1 change: 1 addition & 0 deletions lib/Conversion/StableHLOToTTIR/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ include_directories(${PROJECT_SOURCE_DIR}/include)
add_mlir_library(TTMLIRStableHLOToTTIR
StableHLOToTTIRPatterns.cpp
StableHLOToTTIRPass.cpp
ArithToStableHLOPass.cpp

ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/ttmlir/Conversion/StableHLOToTTIR
Expand Down
38 changes: 0 additions & 38 deletions lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ using namespace mlir::tt;
namespace mlir::tt::ttir {

#define GEN_PASS_DEF_CONVERTSTABLEHLOTOTTIR
#define GEN_PASS_DEF_CONVERTARITHTOSTABLEHLO
#include "ttmlir/Conversion/Passes.h.inc"

} // namespace mlir::tt::ttir
Expand Down Expand Up @@ -93,39 +92,6 @@ struct ConvertStableHLOToTTIRPass
}
};

struct ConvertArithToStableHLOPass
: public ttir::impl::ConvertArithToStableHLOBase<
ConvertArithToStableHLOPass> {
void runOnOperation() final {
mlir::ConversionTarget target(getContext());

target.addIllegalDialect<mlir::arith::ArithDialect>();
target.addLegalDialect<mlir::stablehlo::StablehloDialect>();
target.addLegalOp<mlir::ModuleOp>();
target.addLegalOp<mlir::func::FuncOp>();
target.addLegalOp<mlir::func::ReturnOp>();

// For now keep the same type assuming StableHLO ops operate on builtin
// tensor.
TypeConverter typeConverter;
typeConverter.addConversion([](Type type) {
assert(isa<RankedTensorType>(type) &&
"only ranked tensor type supported");
return type;
});
RewritePatternSet patterns(&getContext());

populateArithToStableHLOPatterns(&getContext(), patterns, typeConverter);

// Apply conversion.
if (failed(
applyFullConversion(getOperation(), target, std::move(patterns)))) {
signalPassFailure();
return;
}
}
};

} // namespace

namespace mlir::tt {
Expand All @@ -134,8 +100,4 @@ std::unique_ptr<OperationPass<ModuleOp>> createConvertStableHLOToTTIRPass() {
return std::make_unique<ConvertStableHLOToTTIRPass>();
}

std::unique_ptr<OperationPass<ModuleOp>> createConvertArithToStableHLOPass() {
return std::make_unique<ConvertArithToStableHLOPass>();
}

} // namespace mlir::tt
24 changes: 0 additions & 24 deletions lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
#include "ttmlir/Conversion/StableHLOToTTIR/StableHLOToTTIR.h"

#include "mlir/Dialect/Traits.h"
#include <mlir/Dialect/Arith/IR/Arith.h>
#include <mlir/Dialect/Func/Transforms/FuncConversions.h>
#include <mlir/Dialect/Tensor/IR/Tensor.h>
#include <mlir/IR/BuiltinAttributes.h>
Expand Down Expand Up @@ -880,27 +879,4 @@ void populateStableHLOToTTIRPatterns(MLIRContext *ctx,
addReshapeOpConversionPattern(ctx, patterns, typeConverter);
}

class ArithToStableHLOConstantOpConversionPattern
: public OpConversionPattern<mlir::arith::ConstantOp> {

using OpConversionPattern<mlir::arith::ConstantOp>::OpConversionPattern;

public:
LogicalResult
matchAndRewrite(mlir::arith::ConstantOp srcOp,
mlir::arith::ConstantOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {

rewriter.replaceOpWithNewOp<mlir::stablehlo::ConstantOp>(srcOp,
srcOp.getValue());
return success();
}
};

void populateArithToStableHLOPatterns(MLIRContext *ctx,
RewritePatternSet &patterns,
TypeConverter &typeConverter) {
patterns.add<ArithToStableHLOConstantOpConversionPattern>(typeConverter, ctx);
}

} // namespace mlir::tt
4 changes: 3 additions & 1 deletion lib/Dialect/TTIR/Pipelines/TTIRPipelines.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@ namespace mlir::tt::ttir {
#ifdef TTMLIR_ENABLE_STABLEHLO
void createStableHLOToTTIRPipeline(
OpPassManager &pm, const StableHLOToTTIRPipelineOptions &options) {
pm.addPass(createConvertArithToStableHLOPass());
if (options.arithDialectConversionsEnabled) {
pm.addPass(createConvertArithToStableHLOPass());
}
pm.addPass(createConvertStableHLOToTTIRPass());
if (options.removeDeadValuesEnabled) {
pm.addPass(mlir::createRemoveDeadValuesPass());
Expand Down

0 comments on commit 4fa4e66

Please sign in to comment.