From e2b136da793e4edc0f72953bed57b4d651815684 Mon Sep 17 00:00:00 2001 From: Hideto Ueno Date: Sat, 7 Sep 2024 14:19:39 +0900 Subject: [PATCH] [python] Add `walk_with_filter` to walk subset of IR This adds `walk_with_filter` python method to invoke callbacks only for subset of operations. We require GIL to call python function so this could improve performance of the tools based on Python API. ``` Starting walk_with_filter walk_with_filter elapsed time: 0.005462 seconds cnt=1 Starting operation.walk operation.walk elapsed time: 1.061360 seconds cnt=2 ``` --- .../Bindings/Python/support/walk.py | 76 ++++++++++++++++ lib/Bindings/Python/CIRCTModule.cpp | 4 +- .../{DialectModules.h => CIRCTModules.h} | 9 +- lib/Bindings/Python/CMakeLists.txt | 1 + lib/Bindings/Python/ESIModule.cpp | 2 +- lib/Bindings/Python/HWModule.cpp | 2 +- lib/Bindings/Python/MSFTModule.cpp | 2 +- lib/Bindings/Python/OMModule.cpp | 2 +- lib/Bindings/Python/SVModule.cpp | 2 +- lib/Bindings/Python/SeqModule.cpp | 2 +- lib/Bindings/Python/SupportModule.cpp | 86 +++++++++++++++++++ lib/Bindings/Python/support.py | 12 +++ 12 files changed, 189 insertions(+), 11 deletions(-) create mode 100644 integration_test/Bindings/Python/support/walk.py rename lib/Bindings/Python/{DialectModules.h => CIRCTModules.h} (76%) create mode 100644 lib/Bindings/Python/SupportModule.cpp diff --git a/integration_test/Bindings/Python/support/walk.py b/integration_test/Bindings/Python/support/walk.py new file mode 100644 index 000000000000..a7594cbf1cea --- /dev/null +++ b/integration_test/Bindings/Python/support/walk.py @@ -0,0 +1,76 @@ +# REQUIRES: bindings_python +# RUN: %PYTHON% %s | FileCheck %s + +import circt +from circt.support import walk_with_filter +from circt.dialects import hw +from circt.ir import Context, Module, WalkOrder, WalkResult + + +def test_walk_with_filter(): + ctx = Context() + circt.register_dialects(ctx) + module = Module.parse( + r""" + builtin.module { + hw.module @f() { + hw.output + } + } + """, + ctx, + ) + + def callback(op): + print(op.name) + return WalkResult.ADVANCE + + # Test post-order walk. + # CHECK: Post-order + # CHECK-NEXT: hw.output + # CHECK-NEXT: hw.module + # CHECK-NOT: builtin.module + print("Post-order") + walk_with_filter(module.operation, [hw.HWModuleOp, hw.OutputOp], callback, + WalkOrder.POST_ORDER) + + # Test pre-order walk. + # CHECK-NEXT: Pre-order + # CHECK-NOT: builtin.module + # CHECK-NEXT: hw.module + # CHECK-NEXT: hw.output + print("Pre-order") + walk_with_filter(module.operation, [hw.HWModuleOp, hw.OutputOp], callback, + WalkOrder.PRE_ORDER) + + # Test interrupt. + # CHECK-NEXT: Interrupt post-order + # CHECK-NEXT: hw.output + print("Interrupt post-order") + + def interrupt_callback(op): + print(op.name) + return WalkResult.INTERRUPT + + walk_with_filter(module.operation, [hw.OutputOp], interrupt_callback, + WalkOrder.POST_ORDER) + + # Test exception. + # CHECK: Exception + # CHECK-NEXT: hw.output + # CHECK-NEXT: Exception raised + print("Exception") + + def exception_callback(op): + print(op.name) + raise ValueError + return WalkResult.ADVANCE + + try: + walk_with_filter(module.operation, [hw.OutputOp], exception_callback, + WalkOrder.POST_ORDER) + except RuntimeError: + print("Exception raised") + + +test_walk_with_filter() diff --git a/lib/Bindings/Python/CIRCTModule.cpp b/lib/Bindings/Python/CIRCTModule.cpp index e86ec8a53ac9..8510a04cf69c 100644 --- a/lib/Bindings/Python/CIRCTModule.cpp +++ b/lib/Bindings/Python/CIRCTModule.cpp @@ -6,7 +6,7 @@ // //===----------------------------------------------------------------------===// -#include "DialectModules.h" +#include "CIRCTModules.h" #include "circt-c/Conversion.h" #include "circt-c/Dialect/Comb.h" @@ -142,4 +142,6 @@ PYBIND11_MODULE(_circt, m) { circt::python::populateDialectOMSubmodule(om); py::module sv = m.def_submodule("_sv", "SV API"); circt::python::populateDialectSVSubmodule(sv); + py::module support = m.def_submodule("_support", "CIRCT support"); + circt::python::populateSupportSubmodule(support); } diff --git a/lib/Bindings/Python/DialectModules.h b/lib/Bindings/Python/CIRCTModules.h similarity index 76% rename from lib/Bindings/Python/DialectModules.h rename to lib/Bindings/Python/CIRCTModules.h index 933aea340f8e..230b36c864f9 100644 --- a/lib/Bindings/Python/DialectModules.h +++ b/lib/Bindings/Python/CIRCTModules.h @@ -1,4 +1,4 @@ -//===- DialectModules.h - Populate submodules -----------------------------===// +//===- CIRCTModules.h - Populate submodules -------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -6,12 +6,12 @@ // //===----------------------------------------------------------------------===// // -// Functions to populate each dialect's submodule (if provided). +// Functions to populate submodules in CIRCT (if provided). // //===----------------------------------------------------------------------===// -#ifndef CIRCT_BINDINGS_PYTHON_DIALECTMODULES_H -#define CIRCT_BINDINGS_PYTHON_DIALECTMODULES_H +#ifndef CIRCT_BINDINGS_PYTHON_CIRCTMODULES_H +#define CIRCT_BINDINGS_PYTHON_CIRCTMODULES_H #include @@ -24,6 +24,7 @@ void populateDialectMSFTSubmodule(pybind11::module &m); void populateDialectOMSubmodule(pybind11::module &m); void populateDialectSeqSubmodule(pybind11::module &m); void populateDialectSVSubmodule(pybind11::module &m); +void populateSupportSubmodule(pybind11::module &m); } // namespace python } // namespace circt diff --git a/lib/Bindings/Python/CMakeLists.txt b/lib/Bindings/Python/CMakeLists.txt index e0c02b0e8608..8f2d07d85cc0 100644 --- a/lib/Bindings/Python/CMakeLists.txt +++ b/lib/Bindings/Python/CMakeLists.txt @@ -20,6 +20,7 @@ declare_mlir_python_extension(CIRCTBindingsPythonExtension OMModule.cpp MSFTModule.cpp SeqModule.cpp + SupportModule.cpp SVModule.cpp EMBED_CAPI_LINK_LIBS CIRCTCAPIComb diff --git a/lib/Bindings/Python/ESIModule.cpp b/lib/Bindings/Python/ESIModule.cpp index d143426be657..6a4d22efbc41 100644 --- a/lib/Bindings/Python/ESIModule.cpp +++ b/lib/Bindings/Python/ESIModule.cpp @@ -6,7 +6,7 @@ // //===----------------------------------------------------------------------===// -#include "DialectModules.h" +#include "CIRCTModules.h" #include "circt/Dialect/ESI/ESIDialect.h" diff --git a/lib/Bindings/Python/HWModule.cpp b/lib/Bindings/Python/HWModule.cpp index 1db80da6bff2..e6b8883066ff 100644 --- a/lib/Bindings/Python/HWModule.cpp +++ b/lib/Bindings/Python/HWModule.cpp @@ -6,7 +6,7 @@ // //===----------------------------------------------------------------------===// -#include "DialectModules.h" +#include "CIRCTModules.h" #include "circt-c/Dialect/HW.h" diff --git a/lib/Bindings/Python/MSFTModule.cpp b/lib/Bindings/Python/MSFTModule.cpp index 9691d5b95807..97acc2f80e17 100644 --- a/lib/Bindings/Python/MSFTModule.cpp +++ b/lib/Bindings/Python/MSFTModule.cpp @@ -6,7 +6,7 @@ // //===----------------------------------------------------------------------===// -#include "DialectModules.h" +#include "CIRCTModules.h" #include "circt-c/Dialect/MSFT.h" #include "circt/Dialect/MSFT/MSFTDialect.h" diff --git a/lib/Bindings/Python/OMModule.cpp b/lib/Bindings/Python/OMModule.cpp index 7908867d116a..8f9deb6808c9 100644 --- a/lib/Bindings/Python/OMModule.cpp +++ b/lib/Bindings/Python/OMModule.cpp @@ -6,7 +6,7 @@ // //===----------------------------------------------------------------------===// -#include "DialectModules.h" +#include "CIRCTModules.h" #include "circt-c/Dialect/HW.h" #include "circt-c/Dialect/OM.h" #include "mlir-c/BuiltinAttributes.h" diff --git a/lib/Bindings/Python/SVModule.cpp b/lib/Bindings/Python/SVModule.cpp index 12d75078df33..74b3df2f1d47 100644 --- a/lib/Bindings/Python/SVModule.cpp +++ b/lib/Bindings/Python/SVModule.cpp @@ -6,7 +6,7 @@ // //===----------------------------------------------------------------------===// -#include "DialectModules.h" +#include "CIRCTModules.h" #include "circt-c/Dialect/SV.h" #include "mlir-c/Bindings/Python/Interop.h" diff --git a/lib/Bindings/Python/SeqModule.cpp b/lib/Bindings/Python/SeqModule.cpp index 97d6ebcff968..3374b193b0da 100644 --- a/lib/Bindings/Python/SeqModule.cpp +++ b/lib/Bindings/Python/SeqModule.cpp @@ -6,7 +6,7 @@ // //===----------------------------------------------------------------------===// -#include "DialectModules.h" +#include "CIRCTModules.h" #include "circt-c/Dialect/Seq.h" #include "mlir/Bindings/Python/PybindAdaptors.h" diff --git a/lib/Bindings/Python/SupportModule.cpp b/lib/Bindings/Python/SupportModule.cpp new file mode 100644 index 000000000000..472b60339686 --- /dev/null +++ b/lib/Bindings/Python/SupportModule.cpp @@ -0,0 +1,86 @@ +//===- SupportModule.cpp - Support API pybind module ----------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "CIRCTModules.h" + +#include "mlir/Bindings/Python/PybindAdaptors.h" + +#include "PybindUtils.h" +#include "mlir-c/Support.h" +#include +#include +#include + +namespace py = pybind11; + +using namespace circt; +using namespace mlir::python::adaptors; + +/// Populate the support python module. +void circt::python::populateSupportSubmodule(py::module &m) { + m.doc() = "CIRCT Python utils"; + // Walk with filter. + m.def( + "_walk_with_filter", + [](MlirOperation operation, std::vector op_names, + std::function callback, + MlirWalkOrder walkOrder) { + struct UserData { + std::function callback; + bool gotException; + std::string exceptionWhat; + py::object exceptionType; + std::vector op_names; + }; + + std::vector op_names_identifiers; + + // Construct MlirIdentifier from string to perform pointer comparison. + for (auto &op_name : op_names) + op_names_identifiers.push_back(mlirIdentifierGet( + mlirOperationGetContext(operation), + mlirStringRefCreateFromCString(op_name.c_str()))); + + UserData userData{callback, false, {}, {}, op_names_identifiers}; + MlirOperationWalkCallback walkCallback = [](MlirOperation op, + void *userData) { + UserData *calleeUserData = static_cast(userData); + auto op_name = mlirOperationGetName(op); + + // Check if the operation name is in the filter. + bool inFilter = false; + for (auto &op_name_identifier : calleeUserData->op_names) { + if (mlirIdentifierEqual(op_name, op_name_identifier)) { + inFilter = true; + break; + } + } + + // If the operation name is not in the filter, skip it. + if (!inFilter) + return MlirWalkResult::MlirWalkResultAdvance; + + try { + return (calleeUserData->callback)(op); + } catch (py::error_already_set &e) { + calleeUserData->gotException = true; + calleeUserData->exceptionWhat = e.what(); + calleeUserData->exceptionType = e.type(); + return MlirWalkResult::MlirWalkResultInterrupt; + } + }; + mlirOperationWalk(operation, walkCallback, &userData, walkOrder); + if (userData.gotException) { + std::string message("Exception raised in callback: "); + message.append(userData.exceptionWhat); + throw std::runtime_error(message); + } + }, + py::arg("op"), py::arg("op_names"), py::arg("callback"), + py::arg("walk_order")); +} diff --git a/lib/Bindings/Python/support.py b/lib/Bindings/Python/support.py index 47f1a048fdc3..38ce7d998e86 100644 --- a/lib/Bindings/Python/support.py +++ b/lib/Bindings/Python/support.py @@ -4,6 +4,8 @@ from . import ir +from ._mlir_libs._circt._support import _walk_with_filter +from .ir import Operation from contextlib import AbstractContextManager from contextvars import ContextVar from typing import List @@ -409,3 +411,13 @@ def create_default_value(self, index, data_type, arg_name): def operation(self): """Get the operation associated with this builder.""" return self.opview.operation + + +# Helper function to walk operation with a filter on operation names. +# `op_views` is a list of operation views to visit. This is a wrapper +# around the C++ implementation of walk_with_filter. +def walk_with_filter(operation: Operation, op_views: List[ir.OpView], callback, + walk_order): + op_names_identifiers = [name.OPERATION_NAME for name in op_views] + return _walk_with_filter(operation, op_names_identifiers, callback, + walk_order)