From 35ec030c86f59f45379436304b74c48717981354 Mon Sep 17 00:00:00 2001 From: "Tung D. Le" Date: Mon, 4 Mar 2024 16:39:32 +0900 Subject: [PATCH] Clean up ElideKrnlGlobalConstants Signed-off-by: Tung D. Le --- src/Tools/onnx-mlir-opt/RegisterPasses.cpp | 4 - src/Transform/CMakeLists.txt | 8 -- src/Transform/ElideKrnlGlobalConstants.cpp | 115 --------------------- src/Transform/ElideKrnlGlobalConstants.hpp | 39 ------- test/mlir/krnl/krnl_global_elision.mlir | 28 ----- 5 files changed, 194 deletions(-) delete mode 100644 src/Transform/ElideKrnlGlobalConstants.cpp delete mode 100644 src/Transform/ElideKrnlGlobalConstants.hpp delete mode 100644 test/mlir/krnl/krnl_global_elision.mlir diff --git a/src/Tools/onnx-mlir-opt/RegisterPasses.cpp b/src/Tools/onnx-mlir-opt/RegisterPasses.cpp index 82a8fe8d46..2e4f8e2c98 100644 --- a/src/Tools/onnx-mlir-opt/RegisterPasses.cpp +++ b/src/Tools/onnx-mlir-opt/RegisterPasses.cpp @@ -96,10 +96,6 @@ void registerOMPasses(int optLevel) { return createProcessScfParallelPrivatePass(); }); - mlir::registerPass([]() -> std::unique_ptr { - return createElideConstGlobalValuePass(); - }); - mlir::registerPass([]() -> std::unique_ptr { return krnl::createConvertSeqToMemrefPass(); }); diff --git a/src/Transform/CMakeLists.txt b/src/Transform/CMakeLists.txt index 7623171f9a..240f74b4e5 100644 --- a/src/Transform/CMakeLists.txt +++ b/src/Transform/CMakeLists.txt @@ -1,13 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 -add_onnx_mlir_library(OMElideKrnlGlobalConstants - ElideKrnlGlobalConstants.cpp - - LINK_LIBS PUBLIC - OMKrnlOps - MLIRTransformUtils - ) - add_onnx_mlir_library(OMLowerKrnlRegion LowerKrnlRegion.cpp diff --git a/src/Transform/ElideKrnlGlobalConstants.cpp b/src/Transform/ElideKrnlGlobalConstants.cpp deleted file mode 100644 index 5c5bca4d8c..0000000000 --- a/src/Transform/ElideKrnlGlobalConstants.cpp +++ /dev/null @@ -1,115 +0,0 @@ -/* - * SPDX-License-Identifier: Apache-2.0 - */ - -//===- ElideKrnlGlobalConstants.cpp - Krnl Constant lobal Value Elision ---===// -// -// Copyright 2019-2022 The IBM Research Authors. -// -// ============================================================================= -// -// In practice, the constant values of Global Krnl operations may be large -// enough to hinder the readability of the MLIR intermediate representation. -// -// This file creates a pass which elides the explicit values of constant -// global operations. This pass has purely cosmetic purposes and should only be -// run to obtain a compact representation of the program when emitting Krnl -// dialect code. This pass should never be invoked on code meant to be run. -// -//===----------------------------------------------------------------------===// - -#include "mlir/Dialect/Affine/IR/AffineOps.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" - -#include "src/Dialect/Krnl/DialectBuilder.hpp" -#include "src/Dialect/Krnl/KrnlOps.hpp" -#include "src/Pass/Passes.hpp" -#include "src/Support/KrnlSupport.hpp" - -#include "ElideKrnlGlobalConstants.hpp" - -using namespace mlir; -using namespace onnx_mlir; - -constexpr uint64_t KrnlConstGlobalValueElision::kDefaultElisionThreshold; - -mlir::LogicalResult KrnlConstGlobalValueElision::matchAndRewrite( - mlir::KrnlGlobalOp op, mlir::PatternRewriter &rewriter) const { - Location loc = op.getLoc(); - - // Only elide if value is available. - if (!op.getValue().has_value()) - return success(); - - // Only elide dense and dense resource attributes. - if (!(op.getValue()->isa() || - op.getValue()->isa())) - return success(); - - MultiDialectBuilder create(rewriter, loc); - - bool elide = false; - - if (op.getValue()->isa()) { - const auto &valAttr = - op.getValueAttr().dyn_cast_or_null(); - if (valAttr.getNumElements() > elisionThreshold && !valAttr.isSplat()) { - elide = true; - } - } else { - const auto &valAttr = - op.getValueAttr().dyn_cast_or_null(); - if (valAttr.getNumElements() > elisionThreshold) { - elide = true; - } - } - - if (elide) { - IntegerAttr offsetAttr = op.getOffset() ? op.getOffsetAttr() : nullptr; - IntegerAttr alignmentAttr = - op.getAlignment() ? op.getAlignmentAttr() : nullptr; - auto newGlobalOp = - create.krnl.constant(op.getResult().getType().cast(), - op.getName(), std::nullopt, offsetAttr, alignmentAttr); - rewriter.replaceOp(op, newGlobalOp); - } - - return success(); -} - -namespace { -/*! - * Function pass that performs constant value elision of Krnl globals. - */ -class ElideConstGlobalValuePass : public PassWrapper> { -public: - MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ElideConstGlobalValuePass) - - StringRef getArgument() const override { return "elide-krnl-constants"; } - - StringRef getDescription() const override { - return "Elide the constant values of the Global Krnl operations."; - } - - void runOnOperation() override { - auto function = getOperation(); - - ConversionTarget target(getContext()); - RewritePatternSet patterns(&getContext()); - patterns.insert( - &getContext(), KrnlConstGlobalValueElision::kDefaultElisionThreshold); - // No need to test, its ok to fail the apply. - LogicalResult res = - applyPatternsAndFoldGreedily(function, std::move(patterns)); - assert((succeeded(res) || failed(res)) && "remove unused var warning"); - } -}; - -} // namespace - -std::unique_ptr onnx_mlir::createElideConstGlobalValuePass() { - return std::make_unique(); -} diff --git a/src/Transform/ElideKrnlGlobalConstants.hpp b/src/Transform/ElideKrnlGlobalConstants.hpp deleted file mode 100644 index 9cdaf7c9b1..0000000000 --- a/src/Transform/ElideKrnlGlobalConstants.hpp +++ /dev/null @@ -1,39 +0,0 @@ -/* - * SPDX-License-Identifier: Apache-2.0 - */ - -#pragma once - -#include "mlir/Pass/Pass.h" -#include "mlir/Transforms/DialectConversion.h" - -#include "src/Dialect/Krnl/KrnlOps.hpp" - -/*! - * RewritePattern that replaces existing constant Krnl global values - * with a similar operation which preserves all attributes except the value - * attribute. - */ -class KrnlConstGlobalValueElision - : public mlir::OpRewritePattern { -public: - /* - * A threshold value specifying the maximum number of elements a constant - * operation can hold as an attribute. If the number exceeds this threshold, - * constants will be packed together and, in the case where `move-to-file` - * option is enabled, stored as a binary file on disk. This can help preserve - * readability of IR dump and improve compilation speed. - */ - static constexpr uint64_t kDefaultElisionThreshold = 32; - - int64_t elisionThreshold; - - using mlir::OpRewritePattern::OpRewritePattern; - - explicit KrnlConstGlobalValueElision( - mlir::MLIRContext *context, int64_t elisionThreshold) - : OpRewritePattern(context), elisionThreshold(elisionThreshold) {} - - mlir::LogicalResult matchAndRewrite( - mlir::KrnlGlobalOp op, mlir::PatternRewriter &rewriter) const override; -}; diff --git a/test/mlir/krnl/krnl_global_elision.mlir b/test/mlir/krnl/krnl_global_elision.mlir deleted file mode 100644 index dc13ef5d60..0000000000 --- a/test/mlir/krnl/krnl_global_elision.mlir +++ /dev/null @@ -1,28 +0,0 @@ -// RUN: onnx-mlir-opt --elide-krnl-constants %s -split-input-file | FileCheck %s - -// CHECK-LABEL: func @test_elide_krnl_global_constant(%arg0: memref<1xf32>) -> memref<1x70xf32> -func.func @test_elide_krnl_global_constant(%arg0: memref<1xf32>) -> memref<1x70xf32> { - %0 = "krnl.global"() {name = "constant_0", shape = [1, 70], value = dense<[[0., 1.0, 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]> : tensor<1x70xf32>} : () -> memref<1x70xf32> - return %0 : memref<1x70xf32> - - // CHECK: {{.*}} = "krnl.global"() {name = "constant_00", shape = [1, 70]} : () -> memref<1x70xf32> - // CHECK: return {{.*}} : memref<1x70xf32> -} - -// ----- - -func.func @test_elide_krnl_global_constant() -> memref<1x80xf32> { - %0 = "krnl.global"() {name = "constant_0", shape = [1, 80], value = dense_resource : tensor<1x80xf32>} : () -> memref<1x80xf32> - return %0 : memref<1x80xf32> - -// CHECK: {{.*}} = "krnl.global"() {name = "constant_01", shape = [1, 80]} : () -> memref<1x80xf32> -// CHECK: return {{.*}} : memref<1x80xf32> -} - -{-# - dialect_resources: { - builtin: { - hex_constant: "0x010000003F0000003F8000003FC00000400000003F0000003F8000003FC00000400000003F0000003F8000003FC00000400000003F0000003F8000003FC00000400000003F0000003F8000003FC00000400000003F0000003F8000003FC00000400000003F0000003F8000003FC00000400000003F0000003F8000003FC00000400000003F0000003F8000003FC00000400000003F0000003F8000003FC00000400000003F0000003F8000003FC00000400000003F0000003F8000003FC00000400000003F0000003F8000003FC00000400000003F0000003F8000003FC00000400000003F0000003F8000003FC00000400000003F0000003F8000003FC00000400000003F0000003F8000003FC00000400000003F0000003F8000003FC00000400000003F0000003F8000003FC00000400000003F0000003F8000003FC0000040000000" - } - } -#-}