Skip to content

Commit

Permalink
python: trim registration and loading of dialects and passes (#1084)
Browse files Browse the repository at this point in the history
In the interest of merging upstream LLVM quickly, a previous patch
(7f08169) updated the torch-mlir build to register all dialects and
passes through Python bindings.  This patch limits the dialects and
passes to only those that are used in torch-mlir.

Key to this change are the removal of
`MLIRPythonExtension.RegisterEverything` and the introduction of a new
Python module (`_mlir_libs/_site_initialize_0.py`), where we register
the dialects and passes used by torch-mlir.
  • Loading branch information
ashay authored Jul 21, 2022
1 parent c61c99e commit ad283c1
Show file tree
Hide file tree
Showing 8 changed files with 73 additions and 25 deletions.
3 changes: 3 additions & 0 deletions include/torch-mlir-c/Registration.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ extern "C" {
*/
MLIR_CAPI_EXPORTED void torchMlirRegisterAllDialects(MlirContext context);

/** Registers upstream (MLIR) dialects used in Torch-MLIR IRs. */
MLIR_CAPI_EXPORTED void torchMlirRegisterRequiredDialects(MlirContext context);

/** Registers all passes for symbolic access with the global registry. */
MLIR_CAPI_EXPORTED void torchMlirRegisterAllPasses();

Expand Down
17 changes: 17 additions & 0 deletions lib/CAPI/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,23 @@ add_mlir_public_c_api_library(TorchMLIRCAPI

LINK_LIBS PUBLIC
MLIRIR
MLIRAffineToStandard
MLIRArithmeticToLLVM
MLIRArithmeticTransforms
MLIRBufferizationTransforms
MLIRControlFlowToLLVM
MLIRFuncToLLVM
MLIRFuncTransforms
MLIRLinalgToLLVM
MLIRLinalgTransforms
MLIRMathToLLVM
MLIRMemRefToLLVM
MLIRReconcileUnrealizedCasts
MLIRSCFToControlFlow
MLIRSCFTransforms
MLIRTensorTransforms
MLIRTosaToArith
MLIRTosaToLinalg
MLIRSupport
TorchMLIRTorchDialect
TorchMLIRInitAll
Expand Down
44 changes: 40 additions & 4 deletions lib/CAPI/Registration.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,27 @@
#include "torch-mlir-c/Registration.h"

#include "mlir/CAPI/IR.h"
#include "mlir/Conversion/Passes.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Transforms/Passes.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/InitAllPasses.h"
#include "torch-mlir/InitAll.h"

void torchMlirRegisterRequiredDialects(MlirContext context) {
mlir::DialectRegistry registry;
registry.insert<mlir::AffineDialect, mlir::arith::ArithmeticDialect,
mlir::bufferization::BufferizationDialect,
mlir::func::FuncDialect, mlir::linalg::LinalgDialect,
mlir::scf::SCFDialect, mlir::tensor::TensorDialect,
mlir::tosa::TosaDialect>();
unwrap(context)->appendDialectRegistry(registry);
}

void torchMlirRegisterAllDialects(MlirContext context) {
mlir::DialectRegistry registry;
mlir::torch::registerAllDialects(registry);
Expand All @@ -23,4 +39,24 @@ void torchMlirRegisterAllDialects(MlirContext context) {
unwrap(context)->loadAllAvailableDialects();
}

void torchMlirRegisterAllPasses() { mlir::torch::registerAllPasses(); }
void torchMlirRegisterAllPasses() {
mlir::arith::registerArithmeticPasses();
mlir::bufferization::registerBufferizationPasses();
mlir::func::registerFuncPasses();
mlir::registerConvertAffineToStandardPass();
mlir::registerConvertArithmeticToLLVMPass();
mlir::registerConvertControlFlowToLLVMPass();
mlir::registerConvertFuncToLLVMPass();
mlir::registerConvertLinalgToLLVMPass();
mlir::registerConvertMathToLLVMPass();
mlir::registerConvertMemRefToLLVMPass();
mlir::registerLinalgPasses();
mlir::registerReconcileUnrealizedCastsPass();
mlir::registerSCFPasses();
mlir::registerSCFToControlFlowPass();
mlir::registerTosaToArithPass();
mlir::registerTosaToLinalgNamedPass();
mlir::registerTosaToLinalgPass();
mlir::tensor::registerTensorPasses();
mlir::torch::registerAllPasses();
}
13 changes: 4 additions & 9 deletions python/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ declare_mlir_python_sources(TorchMLIRPythonSources.TopLevel
SOURCES
__init__.py
compiler_utils.py
_mlir_libs/_site_initialize_0.py
)

declare_mlir_python_sources(TorchMLIRPythonSources.Dialects
Expand Down Expand Up @@ -102,16 +103,10 @@ add_subdirectory(torch_mlir/eager_mode)
################################################################################

set(_source_components
# TODO: Core is now implicitly building/registering all dialects, increasing
# build burden by ~5x. Make it stop.
# TODO: Reduce dependencies. We need ExecutionEngine and a bunch of passes
# for the reference backend, but logically they can be separate. But seemingly
# the only way to handle that is to create a separate mlir python package
# tree, which seems excessive.
MLIRPythonSources
MLIRPythonSources.Core
MLIRPythonSources.Dialects.func
MLIRPythonSources.ExecutionEngine
MLIRPythonExtension.Core
MLIRPythonExtension.RegisterEverything
MLIRPythonExtension.ExecutionEngine
TorchMLIRPythonSources
TorchMLIRPythonExtensions
)
Expand Down
13 changes: 3 additions & 10 deletions python/TorchMLIRModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

#include "mlir-c/Bindings/Python/Interop.h"
#include "mlir/Bindings/Python/PybindAdaptors.h"
#include "mlir/CAPI/IR.h"
#include "torch-mlir-c/Dialects.h"
#include "torch-mlir-c/Registration.h"

Expand All @@ -19,14 +20,6 @@ PYBIND11_MODULE(_torchMlir, m) {

m.doc() = "torch-mlir main python extension";

m.def(
"register_dialect",
[](MlirContext context, bool load) {
MlirDialectHandle handle = mlirGetDialectHandle__torch__();
mlirDialectHandleRegisterDialect(handle, context);
if (load) {
mlirDialectHandleLoadDialect(handle, context);
}
},
py::arg("context"), py::arg("load") = true);
m.def("register_required_dialects", torchMlirRegisterRequiredDialects,
py::arg("context"));
}
4 changes: 4 additions & 0 deletions python/torch_mlir/_mlir_libs/_site_initialize_0.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from . import _torchMlir

def context_init_hook(context):
_torchMlir.register_required_dialects(context)
2 changes: 1 addition & 1 deletion python/torch_mlir/dialects/torch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,4 @@
# Also available under a BSD-style license. See LICENSE.

from .._torch_ops_gen import *
from ..._mlir_libs._torchMlir import register_dialect
from ..._mlir_libs._torchMlir import register_required_dialects
2 changes: 1 addition & 1 deletion test/python/smoketest.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,4 @@
from torch_mlir.dialects import torch

with torch_mlir.ir.Context() as ctx:
torch.register_dialect(ctx)
torch.register_required_dialects(ctx)

0 comments on commit ad283c1

Please sign in to comment.